Skip to content

Commit edc27db

Browse files
committed
ENH Test tools for jax.jit and dask
1 parent 3754e7c commit edc27db

File tree

11 files changed

+460
-27
lines changed

11 files changed

+460
-27
lines changed

docs/api-reference.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,15 @@
1717
setdiff1d
1818
sinc
1919
```
20+
21+
## Testing utilities
22+
23+
```{eval-rst}
24+
.. currentmodule:: array_api_extra.testing
25+
.. autosummary::
26+
:nosignatures:
27+
:toctree: generated
28+
29+
lazy_xp_function
30+
patch_lazy_xp_functions
31+
```

pixi.lock

Lines changed: 13 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ furo = ">=2023.08.17"
9898
myst-parser = ">=0.13"
9999
sphinx-copybutton = "*"
100100
sphinx-autodoc-typehints = "*"
101+
# Needed to import parsed modules with autodoc
102+
pytest = "*"
101103

102104
[tool.pixi.feature.docs.tasks]
103105
docs = { cmd = "sphinx-build . build/", cwd = "docs" }
@@ -180,8 +182,10 @@ markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific b
180182

181183
[tool.coverage]
182184
run.source = ["array_api_extra"]
183-
report.exclude_also = ['\.\.\.']
184-
185+
report.exclude_also = [
186+
'\.\.\.',
187+
'if TYPE_CHECKING:',
188+
]
185189

186190
# mypy
187191

@@ -221,6 +225,8 @@ reportMissingImports = false
221225
reportMissingTypeStubs = false
222226
# false positives for input validation
223227
reportUnreachable = false
228+
# ruff handles this
229+
reportUnusedParameter = false
224230

225231
executionEnvironments = [
226232
{ root = "tests", reportPrivateUsage = false },
@@ -282,7 +288,10 @@ messages_control.disable = [
282288
"design", # ignore heavily opinionated design checks
283289
"fixme", # allow FIXME comments
284290
"line-too-long", # ruff handles this
291+
"unused-argument", # ruff handles this
285292
"missing-function-docstring", # numpydoc handles this
293+
"import-error", # mypy handles this
294+
"import-outside-toplevel", # optional dependencies
286295
]
287296

288297

src/array_api_extra/_lib/_at.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import operator
77
from collections.abc import Callable
88
from enum import Enum
9+
from numbers import Number
910
from types import ModuleType
1011
from typing import ClassVar, cast
1112

@@ -188,7 +189,7 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
188189
def _update_common(
189190
self,
190191
at_op: _AtOp,
191-
y: Array,
192+
y: Array | Number,
192193
/,
193194
copy: bool | None,
194195
xp: ModuleType | None,
@@ -253,7 +254,7 @@ def _update_common(
253254

254255
def set(
255256
self,
256-
y: Array,
257+
y: Array | Number,
257258
/,
258259
copy: bool | None = None,
259260
xp: ModuleType | None = None,
@@ -269,8 +270,8 @@ def set(
269270
def _iop(
270271
self,
271272
at_op: _AtOp,
272-
elwise_op: Callable[[Array, Array], Array],
273-
y: Array,
273+
elwise_op: Callable[[Array, Array | Number], Array],
274+
y: Array | Number,
274275
/,
275276
copy: bool | None,
276277
xp: ModuleType | None,
@@ -294,7 +295,7 @@ def _iop(
294295

295296
def add(
296297
self,
297-
y: Array,
298+
y: Array | Number,
298299
/,
299300
copy: bool | None = None,
300301
xp: ModuleType | None = None,
@@ -308,7 +309,7 @@ def add(
308309

309310
def subtract(
310311
self,
311-
y: Array,
312+
y: Array | Number,
312313
/,
313314
copy: bool | None = None,
314315
xp: ModuleType | None = None,
@@ -318,7 +319,7 @@ def subtract(
318319

319320
def multiply(
320321
self,
321-
y: Array,
322+
y: Array | Number,
322323
/,
323324
copy: bool | None = None,
324325
xp: ModuleType | None = None,
@@ -328,7 +329,7 @@ def multiply(
328329

329330
def divide(
330331
self,
331-
y: Array,
332+
y: Array | Number,
332333
/,
333334
copy: bool | None = None,
334335
xp: ModuleType | None = None,
@@ -338,7 +339,7 @@ def divide(
338339

339340
def power(
340341
self,
341-
y: Array,
342+
y: Array | Number,
342343
/,
343344
copy: bool | None = None,
344345
xp: ModuleType | None = None,
@@ -348,7 +349,7 @@ def power(
348349

349350
def min(
350351
self,
351-
y: Array,
352+
y: Array | Number,
352353
/,
353354
copy: bool | None = None,
354355
xp: ModuleType | None = None,
@@ -361,7 +362,7 @@ def min(
361362

362363
def max(
363364
self,
364-
y: Array,
365+
y: Array | Number,
365366
/,
366367
copy: bool | None = None,
367368
xp: ModuleType | None = None,

src/array_api_extra/_lib/_testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Testing utilities.
33
44
Note that this is private API; don't expect it to be stable.
5+
See also ..testing for public testing utilities.
56
"""
67

78
from types import ModuleType

0 commit comments

Comments
 (0)