Skip to content

Commit 0683414

Browse files
committed
ENH Test tools for jax.jit and dask
1 parent 3754e7c commit 0683414

File tree

7 files changed

+315
-2
lines changed

7 files changed

+315
-2
lines changed

docs/api-reference.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,15 @@
1717
setdiff1d
1818
sinc
1919
```
20+
21+
### Testing utilities
22+
23+
```{eval-rst}
24+
.. currentmodule:: array_api_extra.testing
25+
.. autosummary::
26+
:nosignatures:
27+
:toctree: generated
28+
29+
lazy_xp_function
30+
patch_lazy_xp_functions
31+
```

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ reportMissingImports = false
221221
reportMissingTypeStubs = false
222222
# false positives for input validation
223223
reportUnreachable = false
224+
# ruff handles this
225+
reportUnusedParameter = false
224226

225227
executionEnvironments = [
226228
{ root = "tests", reportPrivateUsage = false },
@@ -282,7 +284,10 @@ messages_control.disable = [
282284
"design", # ignore heavily opinionated design checks
283285
"fixme", # allow FIXME comments
284286
"line-too-long", # ruff handles this
287+
"unused-argument", # ruff handles this
285288
"missing-function-docstring", # numpydoc handles this
289+
"import-error", # mypy handles this
290+
"import-outside-toplevel", # optional dependencies
286291
]
287292

288293

@@ -293,6 +298,7 @@ checks = [
293298
"all", # report on all checks, except the below
294299
"EX01", # most docstrings do not need an example
295300
"SA01", # data-apis/array-api-extra#87
301+
"SA04", # Missing description for See Also cross-reference
296302
"ES01", # most docstrings do not need an extended summary
297303
]
298304
exclude = [ # don't report on objects that match any of these regex

src/array_api_extra/_lib/_testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Testing utilities.
33
44
Note that this is private API; don't expect it to be stable.
5+
See also ..testing for public testing utilities.
56
"""
67

78
from types import ModuleType

src/array_api_extra/testing.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""
2+
Public testing utilities.
3+
4+
See also _lib._testing for additional private testing utilities.
5+
"""
6+
7+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
8+
from __future__ import annotations
9+
10+
from collections.abc import Callable, Iterable, Sequence
11+
from functools import wraps
12+
from types import ModuleType
13+
from typing import TYPE_CHECKING, Any, TypeVar, cast
14+
15+
import pytest
16+
17+
from array_api_extra._lib._utils._compat import is_dask_namespace, is_jax_namespace
18+
19+
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
20+
21+
if TYPE_CHECKING:
22+
# TODO move outside TYPE_CHECKING
23+
# depends on scikit-learn abandoning Python 3.9
24+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
25+
from typing import ParamSpec
26+
27+
P = ParamSpec("P")
28+
else:
29+
# Sphinx hacks
30+
class P: # pylint: disable=missing-class-docstring
31+
args: tuple
32+
kwargs: dict
33+
34+
35+
T = TypeVar("T")
36+
37+
38+
def lazy_xp_function( # type: ignore[no-any-explicit]
39+
func: Callable[..., Any],
40+
*,
41+
dask_disable_compute: bool = True,
42+
jax_jit: bool = True,
43+
static_argnums: int | Sequence[int] | None = None,
44+
static_argnames: str | Iterable[str] | None = None,
45+
) -> None: # numpydoc ignore=GL07
46+
"""
47+
Tag a function to be tested for lazy backends.
48+
49+
Tag a function, which must be imported in the test module globals, so that when any
50+
tests defined in the same module are executed with `xp=jax.numpy` the function is
51+
replaced with a jitted version of itself, and when it is executed with
52+
`xp=dask.array` the function will raise if it attempts to materialize the graph.
53+
54+
This will be later expanded to provide test coverage for other lazy backends.
55+
56+
Parameters
57+
----------
58+
func : callable
59+
Function to be tested.
60+
dask_disable_compute : bool, optional
61+
Set to True to raise an error if `func` attempts to call `dask.compute()` or
62+
`dask.persist()`. This is typically inadvertently triggered by `bool()`,
63+
`float()`, and `np.asarray()`. Set to False to allow these calls, knowing that
64+
they are going to be extremely detrimental for performance.
65+
jax_jit : bool, optional
66+
Set to True to replace `func` with `jax.jit(func)` when calling the
67+
`patch_lazy_xp_functions` test helper with `xp=jax.numpy`.
68+
Set to False if `func` is only compatible with eager (non-jitted) JAX.
69+
Default: True.
70+
static_argnums : int | Sequence[int], optional
71+
Passed to jax.jit.
72+
Positional arguments to treat as static (trace- and compile-time constant).
73+
Default: infer from static_argnames using `inspect.signature(func)`.
74+
static_argnames : str | Iterable[str], optional
75+
Passed to jax.jit.
76+
Named arguments to treat as static (compile-time constant).
77+
Default: infer from static_argnums using `inspect.signature(func)`.
78+
79+
See Also
80+
--------
81+
patch_lazy_xp_functions
82+
jax.jit
83+
84+
Examples
85+
--------
86+
In `test_mymodule.py`::
87+
88+
from array_api_extra.testing import lazy_xp_function
89+
from mymodule import myfunc
90+
91+
lazy_xp_function(myfunc)
92+
93+
def test_myfunc(xp):
94+
a = xp.asarray([1, 2])
95+
# When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
96+
# When xp=dask.array, crash on compute() or persist()
97+
b = myfunc(a)
98+
99+
Notes
100+
-----
101+
A test function can circumvent this monkey-patching system by calling `func` an
102+
attribute of the original module. You need to sanitize your code to
103+
make sure this does not happen.
104+
105+
Example::
106+
107+
import mymodule
108+
from mymodule import myfunc
109+
110+
lazy_xp_function(myfunc)
111+
112+
def test_myfunc(xp):
113+
a = xp.asarray([1, 2])
114+
b = myfunc(a) # This is jitted when xp=jax.numpy
115+
c = mymodule.myfunc(a) # This is not
116+
"""
117+
func.dask_disable_compute = dask_disable_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
118+
if jax_jit:
119+
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
120+
"static_argnums": static_argnums,
121+
"static_argnames": static_argnames,
122+
}
123+
124+
125+
def patch_lazy_xp_functions(
126+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType
127+
) -> None:
128+
"""
129+
Test lazy execution.
130+
131+
If `xp==jax.numpy`, search for all functions which have been tagged by
132+
`lazy_xp_function` in the globals of the module that defines the current test
133+
and wrap them with `jax.jit`. Unwrap them at the end of the test.
134+
135+
If `xp==dask.array`, wrap the functions with a decorator that disables `compute()`
136+
and `persist()`.
137+
138+
This function should be called by your library's `xp` fixture that runs tests on
139+
multiple backends::
140+
141+
@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
142+
def xp(request, monkeypatch):
143+
patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
144+
return request.param
145+
146+
Parameters
147+
----------
148+
request : pytest.FixtureRequest
149+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
150+
monkeypatch : pytest.MonkeyPatch
151+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
152+
xp : module
153+
Array namespace to be tested.
154+
155+
See Also
156+
--------
157+
lazy_xp_function
158+
pytest.FixtureRequest
159+
"""
160+
globals_ = cast(dict[str, Any], request.module.__dict__) # type: ignore[no-any-explicit]
161+
162+
if is_dask_namespace(xp):
163+
for name, func in globals_.items():
164+
if getattr(func, "dask_disable_compute", False):
165+
wrapped = _dask_disable_compute(func)
166+
monkeypatch.setitem(globals_, name, wrapped)
167+
168+
elif is_jax_namespace(xp):
169+
import jax
170+
171+
for name, func in globals_.items():
172+
kwargs = cast( # type: ignore[no-any-explicit]
173+
"dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None)
174+
)
175+
176+
# suppress unused-ignore to run mypy in -e lint as well as -e dev
177+
if kwargs is not None: # type: ignore[no-untyped-call,unused-ignore]
178+
wrapped = jax.jit(func, **kwargs) # type: ignore[no-untyped-call,unused-ignore]
179+
monkeypatch.setitem(globals_, name, wrapped) # pyright: ignore[reportUnknownArgumentType]
180+
181+
182+
def _dask_disable_compute(
183+
func: Callable[P, T],
184+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
185+
"""
186+
Wrap a function to raise if it attempts to call dask.compute or dask.persist.
187+
"""
188+
import dask.config
189+
190+
def get(*args: object, **kwargs: object) -> object: # noqa: ARG001 # numpydoc ignore=PR01
191+
"""Dask scheduler which will always raise when invoked."""
192+
msg = "Called `dask.compute()` or `dask.persist()`"
193+
raise AssertionError(msg)
194+
195+
@wraps(func)
196+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
197+
with dask.config.set({"scheduler": get}):
198+
return func(*args, **kwargs)
199+
200+
return wrapper

tests/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from array_api_extra._lib._utils._compat import array_namespace
1414
from array_api_extra._lib._utils._compat import device as get_device
1515
from array_api_extra._lib._utils._typing import Device
16+
from array_api_extra.testing import patch_lazy_xp_functions
1617

1718
T = TypeVar("T")
1819
P = ParamSpec("P")
@@ -96,7 +97,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
9697

9798

9899
@pytest.fixture
99-
def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
100+
def xp(
101+
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
102+
) -> ModuleType: # numpydoc ignore=PR01,RT03
100103
"""
101104
Parameterized fixture that iterates on all libraries.
102105
@@ -107,6 +110,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
107110
if library == Backend.NUMPY_READONLY:
108111
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
109112
xp = pytest.importorskip(library.value)
113+
114+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
115+
110116
if library == Backend.JAX:
111117
import jax
112118

tests/test_testing.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66

77
from array_api_extra._lib import Backend
88
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
9+
from array_api_extra._lib._utils._compat import (
10+
array_namespace,
11+
is_dask_namespace,
12+
is_jax_namespace,
13+
)
14+
from array_api_extra._lib._utils._typing import Array
15+
from array_api_extra.testing import lazy_xp_function
916

1017
# mypy: disable-error-code=no-any-decorated
1118
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
@@ -68,3 +75,84 @@ def test_assert_close_tolerance(xp: ModuleType):
6875
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
6976
with pytest.raises(AssertionError):
7077
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
78+
79+
80+
def good_lazy(x: Array) -> Array:
81+
"""A function that behaves well in dask and jax.jit"""
82+
return x * 2.0
83+
84+
85+
def non_materializable(x: Array) -> Array:
86+
"""
87+
This function materializes the input array, so it will fail when wrapped in jax.jit
88+
and it will trigger an expensive computation in dask.
89+
"""
90+
xp = array_namespace(x)
91+
if xp.any(x < 0.0):
92+
msg = "Negative values not allowed"
93+
raise ValueError(msg)
94+
return x
95+
96+
97+
def non_materializable2(x: Array) -> Array:
98+
return non_materializable(x)
99+
100+
101+
lazy_xp_function(good_lazy)
102+
lazy_xp_function(non_materializable2)
103+
104+
105+
def test_lazy_xp_function(xp: ModuleType):
106+
x = xp.asarray([1.0, 2.0])
107+
108+
xp_assert_equal(good_lazy(x), xp.asarray([2.0, 4.0]))
109+
xp_assert_equal(non_materializable(x), xp.asarray([1.0, 2.0])) # Not wrapped
110+
111+
if is_jax_namespace(xp):
112+
with pytest.raises(
113+
TypeError, match="Attempted boolean conversion of traced array"
114+
):
115+
non_materializable2(x) # Wrapped
116+
elif is_dask_namespace(xp):
117+
with pytest.raises(
118+
AssertionError,
119+
match=r"Called `dask.compute\(\)` or `dask.persist\(\)`",
120+
):
121+
non_materializable2(x)
122+
else:
123+
xp_assert_equal(non_materializable2(x), xp.asarray([1.0, 2.0]))
124+
125+
126+
def static_params(x: Array, n: int, flag: bool = False) -> Array:
127+
"""Function with static parameters that must not be jitted"""
128+
if flag and n > 0: # This fails if n or flag are jitted arrays
129+
return x * 2.0
130+
return x * 3.0
131+
132+
133+
def static_params1(x: Array, n: int, flag: bool = False) -> Array:
134+
return static_params(x, n, flag)
135+
136+
137+
def static_params2(x: Array, n: int, flag: bool = False) -> Array:
138+
return static_params(x, n, flag)
139+
140+
141+
def static_params3(x: Array, n: int, flag: bool = False) -> Array:
142+
return static_params(x, n, flag)
143+
144+
145+
lazy_xp_function(static_params1, static_argnums=(1, 2))
146+
lazy_xp_function(static_params2, static_argnames=("n", "flag"))
147+
lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag")
148+
149+
150+
@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3])
151+
def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]): # type: ignore[no-any-explicit]
152+
x = xp.asarray([1.0, 2.0])
153+
xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0]))
154+
xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0]))
155+
xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0]))
156+
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
157+
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
158+
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))

0 commit comments

Comments
 (0)