Skip to content

Commit 197a523

Browse files
committed
refactor
1 parent 9ce1427 commit 197a523

File tree

2 files changed

+85
-75
lines changed

2 files changed

+85
-75
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import io
66
import math
77
import pickle
8-
from collections.abc import Generator, Iterable
8+
from collections.abc import Callable, Generator, Hashable, Iterable
9+
from functools import wraps
910
from types import ModuleType, NoneType
10-
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
11+
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, cast
1112

1213
from . import _compat
1314
from ._compat import (
@@ -29,6 +30,7 @@ def override(func):
2930
return func
3031

3132

33+
P = ParamSpec("P")
3234
T = TypeVar("T")
3335

3436

@@ -38,6 +40,7 @@ def override(func):
3840
"eager_shape",
3941
"in1d",
4042
"is_python_scalar",
43+
"jax_autojit",
4144
"mean",
4245
"meta_namespace",
4346
"pickle_without",
@@ -368,7 +371,7 @@ def pickle_without(
368371
>>> class A:
369372
... def __repr__(self):
370373
... return "<A>"
371-
>>> obj = {1: A(), 2: [A(), NS(), A()]} # Any serializable object
374+
>>> obj = {1: A(), 2: [A(), NS(), A()]}
372375
>>> pik, instances, unpickleable = pickle_without(obj, A)
373376
>>> instances, unpickleable
374377
([<A>, <A>, <A>], [<NS>])
@@ -396,7 +399,6 @@ class Pickler(pickle.Pickler): # numpydoc ignore=GL01,RT01
396399

397400
@override
398401
def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
399-
400402
# Fast exit in case of basic builtin types.
401403
# Note that basic collections (tuple, list, etc.) are in this;
402404
# persistent_id() will be called again with their contents.
@@ -416,7 +418,9 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
416418
return id_, 0
417419

418420
try:
419-
_ = obj.__reduce__()
421+
# a class that defines __slots__ without defining __getstate__
422+
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
423+
_ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
420424
except Exception: # pylint: disable=broad-exception-caught
421425
pass
422426
else: # Can be pickled
@@ -425,7 +429,7 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
425429

426430
# May be a global function, which can be pickled
427431
try:
428-
_ = pickle.dumps(obj)
432+
_ = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
429433
except Exception: # pylint: disable=broad-exception-caught
430434
pass
431435
else: # Can be pickled
@@ -438,7 +442,7 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
438442
return id_, 1
439443

440444
f = io.BytesIO()
441-
p = Pickler(f)
445+
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
442446
p.dump(obj)
443447
return f.getvalue(), tuple(instances), tuple(unpickleable)
444448

@@ -480,7 +484,7 @@ def unpickle_without( # type: ignore[explicit-any]
480484
quietly ignored.
481485
"""
482486
iters = iter(instances), iter(unpickleable)
483-
seen: dict[int, object] = {}
487+
seen: dict[tuple[int, int], object] = {}
484488

485489
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL01,RT01
486490
"""
@@ -509,3 +513,72 @@ def persistent_load(self, pid: tuple[int, int]) -> object: # pyright: ignore[re
509513

510514
f = io.BytesIO(pik)
511515
return Unpickler(f).load()
516+
517+
518+
def jax_autojit(
519+
func: Callable[P, T],
520+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
521+
"""
522+
Wrap `func` with ``jax.jit``, with the following differences:
523+
524+
- Array-like arguments and return values are not automatically converted to
525+
``jax.Array`` objects.
526+
- All non-array arguments are automatically treated as static.
527+
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
528+
``pickle``.
529+
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
530+
tuple/list/dict, but can be any object serializable with ``pickle``.
531+
- Automatically descend into non-array arguments and find ``jax.Array`` objects
532+
inside them, then rebuild the arguments when entering `func`, swapping the JAX
533+
concrete arrays with tracer objects.
534+
- Automatically descend into non-array return values and find ``jax.Array`` objects
535+
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
536+
tracer objects with concrete arrays.
537+
"""
538+
import jax
539+
540+
# {
541+
# jit_cache_key(args_pik, args_arrays, args_unpickleable):
542+
# (res_pik, res_unpickleable)
543+
# }
544+
static_return_values: dict[Hashable, tuple[bytes, tuple[object, ...]]] = {}
545+
546+
def jit_cache_key( # type: ignore[no-any-unimported] # numpydoc ignore=GL08
547+
args_pik: bytes,
548+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
549+
args_unpickleable: tuple[Hashable, ...],
550+
) -> Hashable:
551+
return (
552+
args_pik,
553+
tuple((arr.shape, arr.dtype) for arr in args_arrays), # pyright: ignore[reportUnknownArgumentType]
554+
args_unpickleable,
555+
)
556+
557+
def inner( # type: ignore[no-any-unimported] # pyright: ignore[reportUnknownParameterType]
558+
args_pik: bytes,
559+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
560+
args_unpickleable: tuple[Hashable, ...],
561+
) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
562+
args, kwargs = unpickle_without(args_pik, args_arrays, args_unpickleable) # pyright: ignore[reportUnknownArgumentType]
563+
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
564+
res_pik, res_arrays, res_unpickleable = pickle_without(res, jax.Array) # pyright: ignore[reportUnknownArgumentType]
565+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
566+
val = res_pik, res_unpickleable
567+
prev = static_return_values.setdefault(key, val)
568+
assert prev == val, "cache key collision"
569+
return res_arrays
570+
571+
jitted = jax.jit(inner, static_argnums=(0, 2))
572+
573+
@wraps(func)
574+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
575+
args_pik, args_arrays, args_unpickleable = pickle_without(
576+
(args, kwargs),
577+
jax.Array, # pyright: ignore[reportUnknownArgumentType]
578+
)
579+
res_arrays = jitted(args_pik, args_arrays, args_unpickleable)
580+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
581+
res_pik, res_unpickleable = static_return_values[key]
582+
return unpickle_without(res_pik, res_arrays, res_unpickleable) # pyright: ignore[reportUnknownArgumentType]
583+
584+
return outer

src/array_api_extra/testing.py

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
import contextlib
1010
import enum
1111
import warnings
12-
from collections.abc import Callable, Hashable, Iterator, Sequence
12+
from collections.abc import Callable, Iterator, Sequence
1313
from functools import wraps
1414
from types import ModuleType
1515
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1616

1717
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
18-
from ._lib._utils._helpers import pickle_without, unpickle_without
18+
from ._lib._utils._helpers import jax_autojit, pickle_without, unpickle_without
1919

2020
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
2121

@@ -29,7 +29,7 @@
2929
# Sphinx hacks
3030
SchedulerGetCallable = object
3131

32-
def override(func: object) -> object:
32+
def override(func):
3333
return func
3434

3535

@@ -260,7 +260,7 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
260260
elif is_jax_namespace(xp):
261261
for mod, name, func, tags in iter_tagged():
262262
if tags["jax_jit"]:
263-
wrapped = _jax_autojit(func)
263+
wrapped = jax_autojit(func)
264264
monkeypatch.setattr(mod, name, wrapped)
265265

266266

@@ -334,66 +334,3 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
334334
return unpickle_without(pik, arrays, unpickleable) # pyright: ignore[reportUnknownArgumentType]
335335

336336
return wrapper
337-
338-
339-
def _jax_autojit(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
340-
"""
341-
Wrap `func` with ``jax.jit``, with the following differences:
342-
343-
- Array-like arguments and return values are not automatically converted to
344-
``jax.Array`` objects.
345-
- All non-array arguments are automatically treated as static.
346-
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
347-
``pickle``.
348-
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
349-
tuple/list/dict, but can be any object serializable with ``pickle``.
350-
- Automatically descend into non-array arguments and find ``jax.Array`` objects
351-
inside them.
352-
- Automatically descend into non-array return values and find ``jax.Array`` objects
353-
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
354-
tracer objects with concrete arrays.
355-
"""
356-
import jax
357-
358-
# pickled return values of `func`, minus the JAX arrays
359-
static_return_values: dict[Hashable, tuple[bytes, tuple[object, ...]]] = {}
360-
361-
def jit_cache_key( # type: ignore[no-any-unimported]
362-
args_pik: bytes,
363-
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
364-
args_unpickleable: tuple[Hashable, ...],
365-
) -> Hashable:
366-
return (
367-
args_pik,
368-
tuple((arr.shape, arr.dtype) for arr in args_arrays), # pyright: ignore[reportUnknownArgumentType]
369-
args_unpickleable,
370-
)
371-
372-
def inner( # type: ignore[no-any-unimported] # pyright: ignore[reportUnknownParameterType]
373-
args_pik: bytes,
374-
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
375-
args_unpickleable: tuple[Hashable, ...],
376-
) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
377-
args, kwargs = unpickle_without(args_pik, args_arrays, args_unpickleable) # pyright: ignore[reportUnknownArgumentType]
378-
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
379-
res_pik, res_arrays, res_unpickleable = pickle_without(res, jax.Array) # pyright: ignore[reportUnknownArgumentType]
380-
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
381-
val = res_pik, res_unpickleable
382-
prev = static_return_values.setdefault(key, val)
383-
assert prev == val, "cache key collision"
384-
return res_arrays
385-
386-
jitted = jax.jit(inner, static_argnums=(0, 2))
387-
388-
@wraps(func)
389-
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
390-
args_pik, args_arrays, args_unpickleable = pickle_without(
391-
(args, kwargs),
392-
jax.Array, # pyright: ignore[reportUnknownArgumentType]
393-
)
394-
res_arrays = jitted(args_pik, args_arrays, args_unpickleable)
395-
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
396-
res_pik, res_unpickleable = static_return_values[key]
397-
return unpickle_without(res_pik, res_arrays, res_unpickleable) # pyright: ignore[reportUnknownArgumentType]
398-
399-
return outer

0 commit comments

Comments
 (0)