Skip to content

Commit 43bcb26

Browse files
committed
light version
1 parent ebcdaca commit 43bcb26

File tree

2 files changed

+175
-11
lines changed

2 files changed

+175
-11
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
from __future__ import annotations
44

5+
import copyreg
6+
import io
57
import math
6-
from collections.abc import Generator, Iterable
8+
import pickle
9+
from collections.abc import Callable, Generator, Iterable, Iterator
10+
from contextvars import ContextVar
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import TYPE_CHECKING, Any, TypeVar, cast
913

1014
from . import _compat
1115
from ._compat import (
@@ -22,6 +26,8 @@
2226
# TODO import from typing (requires Python >=3.13)
2327
from typing_extensions import TypeIs
2428

29+
T = TypeVar("T")
30+
2531

2632
__all__ = [
2733
"asarrays",
@@ -31,6 +37,8 @@
3137
"is_python_scalar",
3238
"mean",
3339
"meta_namespace",
40+
"pickle_without",
41+
"unpickle_without",
3442
]
3543

3644

@@ -306,3 +314,127 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
306314
out["boolean indexing"] = True
307315
out["data-dependent shapes"] = True
308316
return out
317+
318+
319+
# Helper of ``extract_objects`` and ``repack_objects``
320+
_repacking_objects: ContextVar[Iterator[object]] = ContextVar("_repacking_objects")
321+
322+
323+
def _expand() -> object: # numpydoc ignore=RT01
324+
"""
325+
Helper of ``extract_objects`` and ``repack_objects``.
326+
327+
Inverse of the reducer function.
328+
329+
Notes
330+
-----
331+
This function must be global in order to be picklable.
332+
"""
333+
try:
334+
return next(_repacking_objects.get())
335+
except StopIteration:
336+
msg = "Not enough objects to repack"
337+
raise ValueError(msg)
338+
339+
340+
def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
341+
"""
342+
Variant of ``pickle.dumps`` that extracts inner objects.
343+
344+
Conceptually, this is similar to passing the ``buffer_callback`` argument to
345+
``pickle.dumps``, but instead of extracting buffers it extracts entire objects.
346+
347+
Parameters
348+
----------
349+
obj : object
350+
The object to pickle.
351+
*classes : type
352+
One or more classes to extract from the object.
353+
The instances of these classes inside ``obj`` will not be pickled.
354+
355+
Returns
356+
-------
357+
bytes
358+
The pickled object. Must be unpickled with :func:`unpickle_without`.
359+
list
360+
All instances of ``classes`` found inside ``obj`` (not pickled).
361+
362+
See Also
363+
--------
364+
pickle.dumps : Standard pickle function.
365+
unpickle_without : Reverse function.
366+
367+
Examples
368+
--------
369+
>>> class A:
370+
... def __repr__(self):
371+
... return "<A>"
372+
... def __reduce__(self):
373+
... assert False, "Not serializable"
374+
>>> obj = {1: A(), 2: [A(), A()]} # Any serializable object
375+
>>> pik, extracted = pickle_without(obj, A)
376+
>>> extracted
377+
[<A>, <A>, <A>]
378+
>>> unpickle_without(pik, extracted)
379+
{1: <A>, 2: [<A>, <A>]}
380+
381+
This can be also used to hot-swap inner objects; the only constraint is that
382+
the number of objects in and out must be the same:
383+
384+
>>> class B:
385+
... def __repr__(self): return "<B>"
386+
>>> unpickle_without(pik, [B(), B(), B()])
387+
{1: <B>, 2: [<B>, <B>]}
388+
"""
389+
extracted = []
390+
391+
def reduce(x: T) -> tuple[Callable[[], object], tuple[()]]: # numpydoc ignore=GL08
392+
extracted.append(x)
393+
return _expand, ()
394+
395+
f = io.BytesIO()
396+
p = pickle.Pickler(f)
397+
398+
# Override the reducer for the given classes and all their
399+
# subclasses (recursively).
400+
p.dispatch_table = copyreg.dispatch_table.copy()
401+
subclasses = list(classes)
402+
while subclasses:
403+
cls = subclasses.pop()
404+
p.dispatch_table[cls] = reduce
405+
subclasses.extend(cls.__subclasses__())
406+
407+
p.dump(obj)
408+
409+
return f.getvalue(), extracted
410+
411+
412+
def unpickle_without(pik: bytes, objects: Iterable[object], /) -> Any: # type: ignore[explicit-any]
413+
"""
414+
Variant of ``pickle.loads``, reverse of ``pickle_without``.
415+
416+
Parameters
417+
----------
418+
pik : bytes
419+
The pickled object generated by ``pickle_without``.
420+
objects : Iterable
421+
The objects to be reinserted into the unpickled object.
422+
Must be the at least the same number of elements as the ones extracted by
423+
``pickle_without``, but does not need to be the same objects or even the
424+
same types of objects. Excess objects, if any, won't be inserted.
425+
426+
Returns
427+
-------
428+
object
429+
The unpickled object, with the objects in ``objects`` inserted back into it.
430+
431+
See Also
432+
--------
433+
pickle_without : Serializing function.
434+
pickle.loads : Standard unpickle function.
435+
"""
436+
tok = _repacking_objects.set(iter(objects))
437+
try:
438+
return pickle.loads(pik)
439+
finally:
440+
_repacking_objects.reset(tok)

src/array_api_extra/testing.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1414

1515
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
16+
from ._lib._utils._helpers import pickle_without, unpickle_without
1617

1718
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
1819

@@ -243,14 +244,10 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
243244

244245
for mod, name, func, tags in iter_tagged():
245246
if tags["jax_jit"]:
246-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
247-
wrapped = cast( # type: ignore[explicit-any]
248-
Callable[..., Any],
249-
jax.jit(
250-
func,
251-
static_argnums=tags["static_argnums"],
252-
static_argnames=tags["static_argnames"],
253-
),
247+
wrapped = _jax_wrap(
248+
func,
249+
static_argnums=tags["static_argnums"],
250+
static_argnames=tags["static_argnames"],
254251
)
255252
monkeypatch.setattr(mod, name, wrapped)
256253

@@ -300,6 +297,7 @@ def _dask_wrap(
300297
After the function returns, materialize the graph in order to re-raise exceptions.
301298
"""
302299
import dask
300+
import dask.array as da
303301

304302
func_name = getattr(func, "__name__", str(func))
305303
n_str = f"only up to {n}" if n else "no"
@@ -319,6 +317,40 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
319317
# Block until the graph materializes and reraise exceptions. This allows
320318
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
321319
# not work on scheduler='distributed', as it would not block.
322-
return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
320+
pik, arrays = pickle_without(out, da.Array)
321+
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
322+
unpickle_without(pik, arrays)
323323

324324
return wrapper
325+
326+
327+
def _jax_wrap(
328+
func: Callable[P, T],
329+
static_argnums: int | Sequence[int] | None,
330+
static_argnames: str | Iterable[str] | None,
331+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
332+
"""
333+
Wrap `func` inside ``jax.jit``.
334+
335+
Accepts non-array return values.
336+
"""
337+
import jax
338+
import jax.numpy as jnp
339+
340+
def inner(*args: P.args, **kwargs: P.kwargs) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
341+
out = func(*args, **kwargs)
342+
pik, arrays = pickle_without(out, jax.Array)
343+
return jnp.frombuffer(pik, dtype=jnp.uint8), *arrays
344+
345+
jitted = jax.jit(
346+
inner, static_argnums=static_argnums, static_argnames=static_argnames
347+
)
348+
cpu = jax.devices("cpu")[0]
349+
350+
@wraps(func)
351+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
352+
arrays = jitted(*args, **kwargs)
353+
pik = bytes(arrays[0].to_device(cpu))
354+
return unpickle_without(pik, arrays[1:])
355+
356+
return outer

0 commit comments

Comments
 (0)