|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | 9 | import contextlib |
10 | | -from collections.abc import Callable, Iterable, Iterator, Sequence |
| 10 | +import enum |
| 11 | +import warnings |
| 12 | +from collections.abc import Callable, Hashable, Iterator, Sequence |
11 | 13 | from functools import wraps |
12 | 14 | from types import ModuleType |
13 | 15 | from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast |
@@ -37,13 +39,22 @@ def override(func: object) -> object: |
37 | 39 | _ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any] |
38 | 40 |
|
39 | 41 |
|
| 42 | +class Deprecated(enum.Enum): |
| 43 | + """Unique type for deprecated parameters.""" |
| 44 | + |
| 45 | + DEPRECATED = 1 |
| 46 | + |
| 47 | + |
| 48 | +DEPRECATED = Deprecated.DEPRECATED |
| 49 | + |
| 50 | + |
40 | 51 | def lazy_xp_function( # type: ignore[explicit-any] |
41 | 52 | func: Callable[..., Any], |
42 | 53 | *, |
43 | 54 | allow_dask_compute: int = 0, |
44 | 55 | jax_jit: bool = True, |
45 | | - static_argnums: int | Sequence[int] | None = None, |
46 | | - static_argnames: str | Iterable[str] | None = None, |
| 56 | + static_argnums: Deprecated = DEPRECATED, |
| 57 | + static_argnames: Deprecated = DEPRECATED, |
47 | 58 | ) -> None: # numpydoc ignore=GL07 |
48 | 59 | """ |
49 | 60 | Tag a function to be tested on lazy backends. |
@@ -79,16 +90,15 @@ def lazy_xp_function( # type: ignore[explicit-any] |
79 | 90 | Default: 0, meaning that `func` must be fully lazy and never materialize the |
80 | 91 | graph. |
81 | 92 | jax_jit : bool, optional |
82 | | - Set to True to replace `func` with ``jax.jit(func)`` after calling the |
83 | | - :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False |
84 | | - if `func` is only compatible with eager (non-jitted) JAX. Default: True. |
85 | | - static_argnums : int | Sequence[int], optional |
86 | | - Passed to jax.jit. Positional arguments to treat as static (compile-time |
87 | | - constant). Default: infer from `static_argnames` using |
88 | | - `inspect.signature(func)`. |
89 | | - static_argnames : str | Iterable[str], optional |
90 | | - Passed to jax.jit. Named arguments to treat as static (compile-time constant). |
91 | | - Default: infer from `static_argnums` using `inspect.signature(func)`. |
| 93 | + Set to True to replace `func` with a variant of ``jax.jit(func)`` |
| 94 | + (read notes below) after calling the :func:`patch_lazy_xp_functions` |
| 95 | + test helper with ``xp=jax.numpy``. |
| 96 | + Set to False if `func` is only compatible with eager (non-jitted) JAX. |
| 97 | + Default: True. |
| 98 | + static_argnums : |
| 99 | + Deprecated; ignored |
| 100 | + static_argnames : |
| 101 | + Deprecated; ignored |
92 | 102 |
|
93 | 103 | See Also |
94 | 104 | -------- |
@@ -165,12 +175,20 @@ def test_myfunc(xp): |
165 | 175 | b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array |
166 | 176 | c = naked.myfunc(a) # This is not |
167 | 177 | """ |
| 178 | + if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED: |
| 179 | + warnings.warn( |
| 180 | + ( |
| 181 | + "The `static_argnums` and `static_argnames` parameters are deprecated " |
| 182 | + "and ignored. They will be removed in a future version." |
| 183 | + ), |
| 184 | + DeprecationWarning, |
| 185 | + stacklevel=2, |
| 186 | + ) |
168 | 187 | tags = { |
169 | 188 | "allow_dask_compute": allow_dask_compute, |
170 | 189 | "jax_jit": jax_jit, |
171 | | - "static_argnums": static_argnums, |
172 | | - "static_argnames": static_argnames, |
173 | 190 | } |
| 191 | + |
174 | 192 | try: |
175 | 193 | func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] |
176 | 194 | except AttributeError: # @cython.vectorize |
@@ -240,15 +258,9 @@ def iter_tagged() -> ( # type: ignore[explicit-any] |
240 | 258 | monkeypatch.setattr(mod, name, wrapped) |
241 | 259 |
|
242 | 260 | elif is_jax_namespace(xp): |
243 | | - import jax |
244 | | - |
245 | 261 | for mod, name, func, tags in iter_tagged(): |
246 | 262 | if tags["jax_jit"]: |
247 | | - wrapped = _jax_wrap( |
248 | | - func, |
249 | | - static_argnums=tags["static_argnums"], |
250 | | - static_argnames=tags["static_argnames"], |
251 | | - ) |
| 263 | + wrapped = _jax_autojit(func) |
252 | 264 | monkeypatch.setattr(mod, name, wrapped) |
253 | 265 |
|
254 | 266 |
|
@@ -319,38 +331,50 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 |
319 | 331 | # not work on scheduler='distributed', as it would not block. |
320 | 332 | pik, arrays = pickle_without(out, da.Array) |
321 | 333 | arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] |
322 | | - unpickle_without(pik, arrays) |
| 334 | + return unpickle_without(pik, arrays) |
323 | 335 |
|
324 | 336 | return wrapper |
325 | 337 |
|
326 | 338 |
|
327 | | -def _jax_wrap( |
328 | | - func: Callable[P, T], |
329 | | - static_argnums: int | Sequence[int] | None, |
330 | | - static_argnames: str | Iterable[str] | None, |
331 | | -) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 |
| 339 | +def _jax_autojit(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 |
332 | 340 | """ |
333 | | - Wrap `func` inside ``jax.jit``. |
334 | | -
|
335 | | - Accepts non-array return values. |
| 341 | + Wrap `func` with ``jax.jit``, with the following differences: |
| 342 | +
|
| 343 | + - Array-like arguments and return values are not automatically converted to |
| 344 | + ``jax.Array`` objects. |
| 345 | + - All non-array arguments are automatically treated as static. |
| 346 | + - Unlike ``jax.jit``, non-array arguments and return values are not limited to |
| 347 | + tuple/list/dict, but can be any object serializable with ``pickle``. |
| 348 | + - Automatically descend into non-array arguments and find ``jax.Array`` objects |
| 349 | + inside them. |
| 350 | + - Automatically descend into non-array return values and find ``jax.Array`` objects |
| 351 | + inside them, then rebuild them downstream of exiting the JIT, swapping the JAX |
| 352 | + tracer objects with concrete arrays. |
336 | 353 | """ |
337 | 354 | import jax |
338 | | - import jax.numpy as jnp |
339 | 355 |
|
340 | | - def inner(*args: P.args, **kwargs: P.kwargs) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08 |
341 | | - out = func(*args, **kwargs) |
342 | | - pik, arrays = pickle_without(out, jax.Array) |
343 | | - return jnp.frombuffer(pik, dtype=jnp.uint8), *arrays |
| 356 | + # pickled return values of `func`, minus the JAX arrays |
| 357 | + res_piks = {} |
344 | 358 |
|
345 | | - jitted = jax.jit( |
346 | | - inner, static_argnums=static_argnums, static_argnames=static_argnames |
347 | | - ) |
348 | | - cpu = jax.devices("cpu")[0] |
| 359 | + def jit_cache_key(args_pik: bytes, *arg_arrays: jax.Array) -> Hashable: # type: ignore[no-any-unimported] |
| 360 | + return (args_pik, *((arr.shape, arr.dtype) for arr in arg_arrays)) |
| 361 | + |
| 362 | + def inner(args_pik: bytes, *arg_arrays: jax.Array) -> list[jax.Array]: # type: ignore[no-any-unimported] # numpydoc ignore=GL08 |
| 363 | + args, kwargs = unpickle_without(args_pik, arg_arrays) |
| 364 | + res = func(*args, **kwargs) |
| 365 | + res_pik, res_arrays = pickle_without(res, jax.Array) |
| 366 | + key = jit_cache_key(args_pik, *arg_arrays) |
| 367 | + prev = res_piks.setdefault(key, res_pik) |
| 368 | + assert prev == res_pik, "cache key collision" |
| 369 | + return res_arrays |
| 370 | + |
| 371 | + jitted = jax.jit(inner, static_argnums=0) |
349 | 372 |
|
350 | 373 | @wraps(func) |
351 | 374 | def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 |
352 | | - arrays = jitted(*args, **kwargs) |
353 | | - pik = bytes(arrays[0].to_device(cpu)) |
354 | | - return unpickle_without(pik, arrays[1:]) |
| 375 | + args_pik, arg_arrays = pickle_without((args, kwargs), jax.Array) |
| 376 | + res_arrays = jitted(args_pik, *arg_arrays) |
| 377 | + res_pik = res_piks[jit_cache_key(args_pik, *arg_arrays)] |
| 378 | + return unpickle_without(res_pik, res_arrays) |
355 | 379 |
|
356 | 380 | return outer |
0 commit comments