Skip to content

Commit c7152a7

Browse files
committed
New function at()
1 parent dc9fcf0 commit c7152a7

File tree

12 files changed

+7565
-1191
lines changed

12 files changed

+7565
-1191
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
strategy:
4949
fail-fast: false
5050
matrix:
51-
environment: [ci-py310, ci-py313]
51+
environment: [ci-py310, ci-py313, ci-backends]
5252
runs-on: [ubuntu-latest]
5353

5454
steps:

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56+
"jax": ("https://jax.readthedocs.io/en/latest", None),
5657
}
5758

5859
nitpick_ignore = [

pixi.lock

Lines changed: 7026 additions & 1174 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.1.1"]
29+
# dependencies = ["array-api-compat>=1.10.0"] # Do not release
3030

3131
[project.optional-dependencies]
3232
tests = [
@@ -63,9 +63,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
6363

6464
[tool.pixi.dependencies]
6565
python = ">=3.10.15,<3.14"
66-
array-api-compat = ">=1.1.1"
66+
# array-api-compat = ">=1.10.0" # Do not release
6767

6868
[tool.pixi.pypi-dependencies]
69+
# Do not release: main at least @ gh#205
70+
array-api-compat = { git = "https://github.com/data-apis/array-api-compat.git" }
6971
array-api-extra = { path = ".", editable = true }
7072

7173
[tool.pixi.feature.lint.dependencies]
@@ -99,7 +101,9 @@ tests-cov = "pytest -v -ra --cov --cov-report=xml --cov-report=term --durations=
99101

100102
clean-vendor-compat = "rm -rf vendor_tests/array_api_compat"
101103
clean-vendor-extra = "rm -rf vendor_tests/array_api_extra"
102-
copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = ["clean-vendor-compat"] }
104+
copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = [
105+
"clean-vendor-compat",
106+
] }
103107
copy-vendor-extra = { cmd = "cp -r src/array_api_extra vendor_tests/", depends-on = ["clean-vendor-extra"] }
104108
tests-vendor = { cmd = "pytest -v vendor_tests", depends-on = ["copy-vendor-compat", "copy-vendor-extra"] }
105109

@@ -130,6 +134,35 @@ python = "~=3.10.0"
130134
[tool.pixi.feature.py313.dependencies]
131135
python = "~=3.13.0"
132136

137+
# Backends that can run on CPU-only hosts
138+
[tool.pixi.feature.backends.target.linux-64.dependencies]
139+
pytorch = "*"
140+
dask = "*"
141+
sparse = ">=0.15"
142+
jax = "*"
143+
144+
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
145+
pytorch = "*"
146+
dask = "*"
147+
sparse = ">=0.15"
148+
jax = "*"
149+
150+
[tool.pixi.feature.backends.target.win-64.dependencies]
151+
# pytorch = "*" # Package unavailable on Windows
152+
dask = "*"
153+
sparse = ">=0.15"
154+
# jax = "*" # Package unavailable on Windows
155+
156+
# Backends that require a GPU host and a CUDA driver
157+
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
158+
cupy = "*"
159+
160+
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
161+
# cupy = "*" # Package unavailable on macOSX
162+
163+
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
164+
cupy = "*"
165+
133166
[tool.pixi.environments]
134167
default = { solve-group = "default" }
135168
lint = { features = ["lint"], solve-group = "default" }
@@ -138,7 +171,9 @@ docs = { features = ["docs"], solve-group = "default" }
138171
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
139172
ci-py310 = ["py310", "tests"]
140173
ci-py313 = ["py313", "tests"]
141-
174+
# CUDA not available on free github actions
175+
ci-backends = ["py310", "tests", "backends"]
176+
tests-backends = ["py310", "tests", "backends", "cuda-backends"]
142177

143178
# pytest
144179

@@ -195,6 +230,8 @@ reportAny = false
195230
reportExplicitAny = false
196231
# data-apis/array-api-strict#6
197232
reportUnknownMemberType = false
233+
# no array-api-compat type stubs
234+
reportUnknownVariableType = false
198235

199236

200237
# Ruff
@@ -236,6 +273,7 @@ ignore = [
236273
"PLR09", # Too many <...>
237274
"PLR2004", # Magic value used in comparison
238275
"ISC001", # Conflicts with formatter
276+
"N801", # Class name should use CapWords convention
239277
"N802", # Function name should be lowercase
240278
"N806", # Variable in function should be lowercase
241279
]
@@ -271,6 +309,7 @@ checks = [
271309
"ES01",
272310
]
273311
exclude = [ # don't report on objects that match any of these regex
312+
'.*test_at.*',
274313
'.*test_funcs.*',
275314
'.*test_utils.*',
276315
'.*test_version.*',

src/array_api_extra/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.4.1.dev0"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

0 commit comments

Comments
 (0)