Skip to content

Commit a9bcb55

Browse files
committed
jax_autojit
1 parent 43bcb26 commit a9bcb55

File tree

8 files changed

+403
-203
lines changed

8 files changed

+403
-203
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 216 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
from __future__ import annotations
44

5-
import copyreg
65
import io
76
import math
87
import pickle
9-
from collections.abc import Callable, Generator, Iterable, Iterator
10-
from contextvars import ContextVar
11-
from types import ModuleType
12-
from typing import TYPE_CHECKING, Any, TypeVar, cast
8+
from collections.abc import Callable, Generator, Hashable, Iterable
9+
from functools import wraps
10+
from types import ModuleType, NoneType
11+
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, cast
1312

1413
from . import _compat
1514
from ._compat import (
@@ -23,9 +22,15 @@
2322
from ._typing import Array
2423

2524
if TYPE_CHECKING: # pragma: no cover
26-
# TODO import from typing (requires Python >=3.13)
27-
from typing_extensions import TypeIs
25+
# TODO import from typing (override requires Python >=3.12; TypeIs requires >=3.13)
26+
from typing_extensions import TypeIs, override
27+
else:
2828

29+
def override(func):
30+
return func
31+
32+
33+
P = ParamSpec("P")
2934
T = TypeVar("T")
3035

3136

@@ -35,6 +40,7 @@
3540
"eager_shape",
3641
"in1d",
3742
"is_python_scalar",
43+
"jax_autojit",
3844
"mean",
3945
"meta_namespace",
4046
"pickle_without",
@@ -316,48 +322,39 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
316322
return out
317323

318324

319-
# Helper of ``extract_objects`` and ``repack_objects``
320-
_repacking_objects: ContextVar[Iterator[object]] = ContextVar("_repacking_objects")
321-
322-
323-
def _expand() -> object: # numpydoc ignore=RT01
324-
"""
325-
Helper of ``extract_objects`` and ``repack_objects``.
326-
327-
Inverse of the reducer function.
328-
329-
Notes
330-
-----
331-
This function must be global in order to be picklable.
332-
"""
333-
try:
334-
return next(_repacking_objects.get())
335-
except StopIteration:
336-
msg = "Not enough objects to repack"
337-
raise ValueError(msg)
325+
_BASIC_TYPES = frozenset((
326+
NoneType, bool, int, float, complex, str, bytes, bytearray,
327+
list, tuple, dict, set, frozenset, range, slice,
328+
)) # fmt: skip
338329

339330

340-
def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
331+
def pickle_without(
332+
obj: object, cls: type[T] | tuple[type[T], ...] = ()
333+
) -> tuple[bytes, tuple[T, ...], tuple[object, ...]]:
341334
"""
342-
Variant of ``pickle.dumps`` that extracts inner objects.
335+
Variant of ``pickle.dumps`` that always succeeds and extracts inner objects.
343336
344337
Conceptually, this is similar to passing the ``buffer_callback`` argument to
345-
``pickle.dumps``, but instead of extracting buffers it extracts entire objects.
338+
``pickle.dumps``, but instead of extracting buffers it extracts entire objects,
339+
which are either not serializable with ``pickle`` (e.g. local classes or functions)
340+
or instances of an explicit list of classes.
346341
347342
Parameters
348343
----------
349344
obj : object
350345
The object to pickle.
351-
*classes : type
352-
One or more classes to extract from the object.
346+
cls : type | tuple[type, ...], optional
347+
One or multiple classes to extract from the object.
353348
The instances of these classes inside ``obj`` will not be pickled.
354349
355350
Returns
356351
-------
357352
bytes
358353
The pickled object. Must be unpickled with :func:`unpickle_without`.
359-
list
360-
All instances of ``classes`` found inside ``obj`` (not pickled).
354+
tuple
355+
All instances of ``cls`` found inside ``obj`` (not pickled).
356+
tuple
357+
All other objects which failed to pickle.
361358
362359
See Also
363360
--------
@@ -366,75 +363,221 @@ def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
366363
367364
Examples
368365
--------
366+
>>> class NS:
367+
... def __repr__(self):
368+
... return "<NS>"
369+
... def __reduce__(self):
370+
... assert False
369371
>>> class A:
370372
... def __repr__(self):
371373
... return "<A>"
372-
... def __reduce__(self):
373-
... assert False, "Not serializable"
374-
>>> obj = {1: A(), 2: [A(), A()]} # Any serializable object
375-
>>> pik, extracted = pickle_without(obj, A)
376-
>>> extracted
377-
[<A>, <A>, <A>]
378-
>>> unpickle_without(pik, extracted)
379-
{1: <A>, 2: [<A>, <A>]}
374+
>>> obj = {1: A(), 2: [A(), NS(), A()]}
375+
>>> pik, instances, unpickleable = pickle_without(obj, A)
376+
>>> instances, unpickleable
377+
((<A>, <A>, <A>), (<NS>,))
378+
>>> unpickle_without(pik, instances, unpickleable)
379+
{1: <A>, 2: [<A>, <NS>, <A>]}
380380
381381
This can be also used to hot-swap inner objects; the only constraint is that
382382
the number of objects in and out must be the same:
383383
384384
>>> class B:
385385
... def __repr__(self): return "<B>"
386-
>>> unpickle_without(pik, [B(), B(), B()])
387-
{1: <B>, 2: [<B>, <B>]}
386+
>>> unpickle_without(pik, [B(), B(), B()], [NS()])
387+
{1: <B>, 2: [<B>, <NS>, <B>]}
388388
"""
389-
extracted = []
390-
391-
def reduce(x: T) -> tuple[Callable[[], object], tuple[()]]: # numpydoc ignore=GL08
392-
extracted.append(x)
393-
return _expand, ()
389+
instances: list[T] = []
390+
unpickleable: list[object] = []
391+
seen: dict[int, Literal[0, 1, None]] = {}
392+
393+
class Pickler(pickle.Pickler): # numpydoc ignore=GL01,RT01
394+
"""Override pickle.Pickler.persistent_id.
395+
396+
TODO consider moving to top-level scope to allow for
397+
the full Pickler API to be used.
398+
"""
399+
400+
@override
401+
def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
402+
# Fast exit in case of basic builtin types.
403+
# Note that basic collections (tuple, list, etc.) are in this;
404+
# persistent_id() will be called again with their contents.
405+
if type(obj) in _BASIC_TYPES: # No subclasses!
406+
return None
407+
408+
id_ = id(obj)
409+
try:
410+
kind = seen[id_]
411+
return None if kind is None else (id_, kind)
412+
except KeyError:
413+
pass
414+
415+
if isinstance(obj, cls):
416+
instances.append(obj) # type: ignore[arg-type]
417+
seen[id_] = 0
418+
return id_, 0
419+
420+
for func in (
421+
# Note: a class that defines __slots__ without defining __getstate__
422+
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
423+
lambda: obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL),
424+
lambda: obj.__reduce__(),
425+
# Global functions can't be reduced via __reduce__, but can be pickled by reference
426+
lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
427+
):
428+
try:
429+
# Probe only: the return value is discarded; all that matters is
430+
# whether the call raises (see the notes on each probe above).
431+
func()
432+
except Exception: # pylint: disable=broad-exception-caught
433+
pass
434+
else: # Can be pickled
435+
seen[id_] = None
436+
return None
437+
438+
# Can't be pickled
439+
unpickleable.append(obj)
440+
seen[id_] = 1
441+
return id_, 1
394442

395443
f = io.BytesIO()
396-
p = pickle.Pickler(f)
397-
398-
# Override the reducer for the given classes and all their
399-
# subclasses (recursively).
400-
p.dispatch_table = copyreg.dispatch_table.copy()
401-
subclasses = list(classes)
402-
while subclasses:
403-
cls = subclasses.pop()
404-
p.dispatch_table[cls] = reduce
405-
subclasses.extend(cls.__subclasses__())
406-
444+
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
407445
p.dump(obj)
446+
return f.getvalue(), tuple(instances), tuple(unpickleable)
408447

409-
return f.getvalue(), extracted
410448

411-
412-
def unpickle_without(pik: bytes, objects: Iterable[object], /) -> Any: # type: ignore[explicit-any]
449+
def unpickle_without( # type: ignore[explicit-any]
450+
pik: bytes,
451+
instances: Iterable[object],
452+
unpickleable: Iterable[object],
453+
/,
454+
) -> Any:
413455
"""
414456
Variant of ``pickle.loads``, reverse of ``pickle_without``.
415457
416458
Parameters
417459
----------
418460
pik : bytes
419461
The pickled object generated by ``pickle_without``.
420-
objects : Iterable
421-
The objects to be reinserted into the unpickled object.
422-
Must be the at least the same number of elements as the ones extracted by
423-
``pickle_without``, but does not need to be the same objects or even the
424-
same types of objects. Excess objects, if any, won't be inserted.
462+
instances : Iterable[object]
463+
Instances of the class or classes explicitly passed to ``pickle_without``,
464+
to be reinserted into the unpickled object.
465+
unpickleable : Iterable[object]
466+
The objects that failed to pickle, as returned by ``pickle_without``.
425467
426468
Returns
427469
-------
428470
object
429-
The unpickled object, with the objects in ``objects`` inserted back into it.
471+
The unpickled object.
430472
431473
See Also
432474
--------
433475
pickle_without : Serializing function.
434476
pickle.loads : Standard unpickle function.
477+
478+
Notes
479+
-----
480+
The second and third parameters of this function must yield at least the same number
481+
of elements as the ones returned by ``pickle_without``, but do not need to be the
482+
same objects, or even the same types of objects. Excess objects, if any, will be
483+
quietly ignored.
484+
"""
485+
iters = iter(instances), iter(unpickleable)
486+
seen: dict[tuple[int, int], object] = {}
487+
488+
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL01,RT01
489+
"""
490+
Override pickle.Unpickler.persistent_load.
491+
492+
TODO consider moving to top-level scope to allow for
493+
the full Unpickler API to be used.
494+
"""
495+
496+
@override
497+
def persistent_load(self, pid: tuple[int, int]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
498+
try:
499+
return seen[pid]
500+
except KeyError:
501+
pass
502+
503+
_, kind = pid
504+
try:
505+
obj = next(iters[kind])
506+
except StopIteration as e:
507+
msg = "Not enough objects to unpickle"
508+
raise ValueError(msg) from e
509+
510+
seen[pid] = obj
511+
return obj
512+
513+
f = io.BytesIO(pik)
514+
return Unpickler(f).load()
515+
516+
517+
def jax_autojit(
518+
func: Callable[P, T],
519+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
520+
"""
521+
Wrap `func` with ``jax.jit``, with the following differences:
522+
523+
- Array-like arguments and return values are not automatically converted to
524+
``jax.Array`` objects.
525+
- All non-array arguments are automatically treated as static.
526+
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
527+
``pickle``.
528+
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
529+
tuple/list/dict, but can be any object serializable with ``pickle``.
530+
- Automatically descend into non-array arguments and find ``jax.Array`` objects
531+
inside them, then rebuild the arguments when entering `func`, swapping the JAX
532+
concrete arrays with tracer objects.
533+
- Automatically descend into non-array return values and find ``jax.Array`` objects
534+
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
535+
tracer objects with concrete arrays.
435536
"""
436-
tok = _repacking_objects.set(iter(objects))
437-
try:
438-
return pickle.loads(pik)
439-
finally:
440-
_repacking_objects.reset(tok)
537+
import jax
538+
539+
# {
540+
# jit_cache_key(args_pik, args_arrays, args_unpickleable):
541+
# (res_pik, res_unpickleable)
542+
# }
543+
static_return_values: dict[Hashable, tuple[bytes, tuple[object, ...]]] = {}
544+
545+
def jit_cache_key( # type: ignore[no-any-unimported] # numpydoc ignore=GL08
546+
args_pik: bytes,
547+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
548+
args_unpickleable: tuple[Hashable, ...],
549+
) -> Hashable:
550+
return (
551+
args_pik,
552+
tuple((arr.shape, arr.dtype) for arr in args_arrays), # pyright: ignore[reportUnknownArgumentType]
553+
args_unpickleable,
554+
)
555+
556+
def inner( # type: ignore[no-any-unimported] # pyright: ignore[reportUnknownParameterType]
557+
args_pik: bytes,
558+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
559+
args_unpickleable: tuple[Hashable, ...],
560+
) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
561+
args, kwargs = unpickle_without(args_pik, args_arrays, args_unpickleable) # pyright: ignore[reportUnknownArgumentType]
562+
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
563+
res_pik, res_arrays, res_unpickleable = pickle_without(res, jax.Array) # pyright: ignore[reportUnknownArgumentType]
564+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
565+
val = res_pik, res_unpickleable
566+
prev = static_return_values.setdefault(key, val)
567+
assert prev == val, "cache key collision"
568+
return res_arrays
569+
570+
jitted = jax.jit(inner, static_argnums=(0, 2))
571+
572+
@wraps(func)
573+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
574+
args_pik, args_arrays, args_unpickleable = pickle_without(
575+
(args, kwargs),
576+
jax.Array, # pyright: ignore[reportUnknownArgumentType]
577+
)
578+
res_arrays = jitted(args_pik, args_arrays, args_unpickleable)
579+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
580+
res_pik, res_unpickleable = static_return_values[key]
581+
return unpickle_without(res_pik, res_arrays, res_unpickleable) # pyright: ignore[reportUnknownArgumentType]
582+
583+
return outer

0 commit comments

Comments
 (0)