Skip to content

Commit 02c579d

Browse files
committed
jax_autojit
1 parent e23e179 commit 02c579d

File tree

6 files changed

+97
-101
lines changed

6 files changed

+97
-101
lines changed

src/array_api_extra/testing.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from __future__ import annotations
88

99
import contextlib
10-
from collections.abc import Callable, Iterable, Iterator, Sequence
10+
import enum
11+
import warnings
12+
from collections.abc import Callable, Hashable, Iterator, Sequence
1113
from functools import wraps
1214
from types import ModuleType
1315
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
@@ -37,13 +39,22 @@ def override(func: object) -> object:
3739
_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any]
3840

3941

42+
class Deprecated(enum.Enum):
43+
"""Unique type for deprecated parameters."""
44+
45+
DEPRECATED = 1
46+
47+
48+
DEPRECATED = Deprecated.DEPRECATED
49+
50+
4051
def lazy_xp_function( # type: ignore[explicit-any]
4152
func: Callable[..., Any],
4253
*,
4354
allow_dask_compute: int = 0,
4455
jax_jit: bool = True,
45-
static_argnums: int | Sequence[int] | None = None,
46-
static_argnames: str | Iterable[str] | None = None,
56+
static_argnums: Deprecated = DEPRECATED,
57+
static_argnames: Deprecated = DEPRECATED,
4758
) -> None: # numpydoc ignore=GL07
4859
"""
4960
Tag a function to be tested on lazy backends.
@@ -79,16 +90,15 @@ def lazy_xp_function( # type: ignore[explicit-any]
7990
Default: 0, meaning that `func` must be fully lazy and never materialize the
8091
graph.
8192
jax_jit : bool, optional
82-
Set to True to replace `func` with ``jax.jit(func)`` after calling the
83-
:func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False
84-
if `func` is only compatible with eager (non-jitted) JAX. Default: True.
85-
static_argnums : int | Sequence[int], optional
86-
Passed to jax.jit. Positional arguments to treat as static (compile-time
87-
constant). Default: infer from `static_argnames` using
88-
`inspect.signature(func)`.
89-
static_argnames : str | Iterable[str], optional
90-
Passed to jax.jit. Named arguments to treat as static (compile-time constant).
91-
Default: infer from `static_argnums` using `inspect.signature(func)`.
93+
Set to True to replace `func` with a variant of ``jax.jit(func)``
94+
(read notes below) after calling the :func:`patch_lazy_xp_functions`
95+
test helper with ``xp=jax.numpy``.
96+
Set to False if `func` is only compatible with eager (non-jitted) JAX.
97+
Default: True.
98+
static_argnums :
99+
Deprecated; ignored
100+
static_argnames :
101+
Deprecated; ignored
92102
93103
See Also
94104
--------
@@ -165,12 +175,20 @@ def test_myfunc(xp):
165175
b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
166176
c = naked.myfunc(a) # This is not
167177
"""
178+
if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED:
179+
warnings.warn(
180+
(
181+
"The `static_argnums` and `static_argnames` parameters are deprecated "
182+
"and ignored. They will be removed in a future version."
183+
),
184+
DeprecationWarning,
185+
stacklevel=2,
186+
)
168187
tags = {
169188
"allow_dask_compute": allow_dask_compute,
170189
"jax_jit": jax_jit,
171-
"static_argnums": static_argnums,
172-
"static_argnames": static_argnames,
173190
}
191+
174192
try:
175193
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
176194
except AttributeError: # @cython.vectorize
@@ -240,15 +258,9 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
240258
monkeypatch.setattr(mod, name, wrapped)
241259

242260
elif is_jax_namespace(xp):
243-
import jax
244-
245261
for mod, name, func, tags in iter_tagged():
246262
if tags["jax_jit"]:
247-
wrapped = _jax_wrap(
248-
func,
249-
static_argnums=tags["static_argnums"],
250-
static_argnames=tags["static_argnames"],
251-
)
263+
wrapped = _jax_autojit(func)
252264
monkeypatch.setattr(mod, name, wrapped)
253265

254266

@@ -319,38 +331,50 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
319331
# not work on scheduler='distributed', as it would not block.
320332
pik, arrays = pickle_without(out, da.Array)
321333
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
322-
unpickle_without(pik, arrays)
334+
return unpickle_without(pik, arrays)
323335

324336
return wrapper
325337

326338

327-
def _jax_wrap(
328-
func: Callable[P, T],
329-
static_argnums: int | Sequence[int] | None,
330-
static_argnames: str | Iterable[str] | None,
331-
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
339+
def _jax_autojit(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
332340
"""
333-
Wrap `func` inside ``jax.jit``.
334-
335-
Accepts non-array return values.
341+
Wrap `func` with ``jax.jit``, with the following differences:
342+
343+
- Array-like arguments and return values are not automatically converted to
344+
``jax.Array`` objects.
345+
- All non-array arguments are automatically treated as static.
346+
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
347+
tuple/list/dict, but can be any object serializable with ``pickle``.
348+
- Automatically descend into non-array arguments and find ``jax.Array`` objects
349+
inside them.
350+
- Automatically descend into non-array return values and find ``jax.Array`` objects
351+
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
352+
tracer objects with concrete arrays.
336353
"""
337354
import jax
338-
import jax.numpy as jnp
339355

340-
def inner(*args: P.args, **kwargs: P.kwargs) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
341-
out = func(*args, **kwargs)
342-
pik, arrays = pickle_without(out, jax.Array)
343-
return jnp.frombuffer(pik, dtype=jnp.uint8), *arrays
356+
# pickled return values of `func`, minus the JAX arrays
357+
res_piks = {}
344358

345-
jitted = jax.jit(
346-
inner, static_argnums=static_argnums, static_argnames=static_argnames
347-
)
348-
cpu = jax.devices("cpu")[0]
359+
def jit_cache_key(args_pik: bytes, *arg_arrays: jax.Array) -> Hashable: # type: ignore[no-any-unimported]
360+
return (args_pik, *((arr.shape, arr.dtype) for arr in arg_arrays))
361+
362+
def inner(args_pik: bytes, *arg_arrays: jax.Array) -> list[jax.Array]: # type: ignore[no-any-unimported] # numpydoc ignore=GL08
363+
args, kwargs = unpickle_without(args_pik, arg_arrays)
364+
res = func(*args, **kwargs)
365+
res_pik, res_arrays = pickle_without(res, jax.Array)
366+
key = jit_cache_key(args_pik, *arg_arrays)
367+
prev = res_piks.setdefault(key, res_pik)
368+
assert prev == res_pik, "cache key collision"
369+
return res_arrays
370+
371+
jitted = jax.jit(inner, static_argnums=0)
349372

350373
@wraps(func)
351374
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
352-
arrays = jitted(*args, **kwargs)
353-
pik = bytes(arrays[0].to_device(cpu))
354-
return unpickle_without(pik, arrays[1:])
375+
args_pik, arg_arrays = pickle_without((args, kwargs), jax.Array)
376+
res_arrays = jitted(args_pik, *arg_arrays)
377+
res_pik = res_piks[jit_cache_key(args_pik, *arg_arrays)]
378+
return unpickle_without(res_pik, res_arrays)
355379

356380
return outer

tests/test_at.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
import pickle
32
from collections.abc import Callable, Generator
43
from contextlib import contextmanager
54
from types import ModuleType
@@ -41,28 +40,11 @@ def at_op(
4140
just a workaround for when one wants to apply jax.jit to `at()` directly,
4241
which is not a common use case.
4342
"""
44-
if isinstance(idx, (slice | tuple)):
45-
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
46-
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
47-
48-
49-
def _at_op(
50-
x: Array,
51-
idx: SetIndex | None,
52-
idx_pickle: bytes | None,
53-
op: _AtOp,
54-
y: Array | object,
55-
copy: bool | None,
56-
xp: ModuleType | None = None,
57-
) -> Array:
58-
"""jitted helper of at_op"""
59-
if idx_pickle:
60-
idx = pickle.loads(idx_pickle)
61-
meth = cast(Callable[..., Array], getattr(at(x, cast(SetIndex, idx)), op.value)) # type: ignore[explicit-any]
43+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
6244
return meth(y, copy=copy, xp=xp)
6345

6446

65-
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
47+
lazy_xp_function(at_op)
6648

6749

6850
@contextmanager

tests/test_funcs.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@
3737
# some xp backends are untyped
3838
# mypy: disable-error-code=no-untyped-def
3939

40-
lazy_xp_function(apply_where, static_argnums=(2, 3), static_argnames="xp")
41-
lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp"))
42-
lazy_xp_function(cov, static_argnames="xp")
43-
lazy_xp_function(create_diagonal, static_argnames=("offset", "xp"))
44-
lazy_xp_function(expand_dims, static_argnames=("axis", "xp"))
45-
lazy_xp_function(kron, static_argnames="xp")
46-
lazy_xp_function(nunique, static_argnames="xp")
47-
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
40+
lazy_xp_function(apply_where)
41+
lazy_xp_function(atleast_nd)
42+
lazy_xp_function(cov)
43+
lazy_xp_function(create_diagonal)
44+
lazy_xp_function(expand_dims)
45+
lazy_xp_function(kron)
46+
lazy_xp_function(nunique)
47+
lazy_xp_function(pad)
4848
# FIXME calls in1d which calls xp.unique_values without size
49-
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
50-
lazy_xp_function(sinc, static_argnames="xp")
49+
lazy_xp_function(setdiff1d, jax_jit=False)
50+
lazy_xp_function(sinc)
5151

5252

5353
class TestApplyWhere:

tests/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# mypy: disable-error-code=no-untyped-usage
2525

2626
# FIXME calls xp.unique_values without size
27-
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
27+
lazy_xp_function(in1d, jax_jit=False)
2828

2929

3030
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse")

tests/test_lazy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from array_api_extra._lib._utils._typing import Array, Device
1616
from array_api_extra.testing import lazy_xp_function
1717

18-
lazy_xp_function(
19-
lazy_apply, static_argnames=("func", "shape", "dtype", "as_numpy", "xp")
20-
)
18+
lazy_xp_function(lazy_apply)
2119

2220
as_numpy = pytest.mark.parametrize(
2321
"as_numpy",
@@ -386,7 +384,7 @@ def eager(
386384
)
387385

388386

389-
lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect_cls", "as_numpy"))
387+
lazy_xp_function(check_lazy_apply_kwargs)
390388

391389

392390
@as_numpy

tests/test_testing.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -179,32 +179,24 @@ def static_params(x: Array, n: int, flag: bool = False) -> Array:
179179
return x * 3.0
180180

181181

182-
def static_params1(x: Array, n: int, flag: bool = False) -> Array:
183-
return static_params(x, n, flag)
182+
lazy_xp_function(static_params)
184183

185184

186-
def static_params2(x: Array, n: int, flag: bool = False) -> Array:
187-
return static_params(x, n, flag)
188-
189-
190-
def static_params3(x: Array, n: int, flag: bool = False) -> Array:
191-
return static_params(x, n, flag)
192-
193-
194-
lazy_xp_function(static_params1, static_argnums=(1, 2))
195-
lazy_xp_function(static_params2, static_argnames=("n", "flag"))
196-
lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag")
197-
198-
199-
@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3])
200-
def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]): # type: ignore[explicit-any]
185+
def test_lazy_xp_function_static_params(xp: ModuleType):
201186
x = xp.asarray([1.0, 2.0])
202-
xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0]))
203-
xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0]))
204-
xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0]))
205-
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
206-
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
207-
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
187+
xp_assert_equal(static_params(x, 1), xp.asarray([3.0, 6.0]))
188+
xp_assert_equal(static_params(x, 1, True), xp.asarray([2.0, 4.0]))
189+
xp_assert_equal(static_params(x, 1, False), xp.asarray([3.0, 6.0]))
190+
xp_assert_equal(static_params(x, 0, False), xp.asarray([3.0, 6.0]))
191+
xp_assert_equal(static_params(x, 1, flag=True), xp.asarray([2.0, 4.0]))
192+
xp_assert_equal(static_params(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
193+
194+
195+
def test_lazy_xp_function_deprecated_static_argnames():
196+
with pytest.warns(DeprecationWarning, match="static_argnames"):
197+
lazy_xp_function(static_params, static_argnames=["flag"]) # type: ignore[arg-type]
198+
with pytest.warns(DeprecationWarning, match="static_argnums"):
199+
lazy_xp_function(static_params, static_argnums=[1]) # type: ignore[arg-type]
208200

209201

210202
try:

0 commit comments

Comments
 (0)