Skip to content

Commit b034e94

Browse files
committed
light version
1 parent 9c9326a commit b034e94

File tree

2 files changed

+164
-11
lines changed

2 files changed

+164
-11
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
from __future__ import annotations
44

5+
import copyreg
6+
import io
57
import math
6-
from collections.abc import Generator, Iterable
8+
import pickle
9+
from collections.abc import Callable, Generator, Iterable, Iterator
10+
from contextvars import ContextVar
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import TYPE_CHECKING, Any, TypeVar, cast
913

1014
from . import _compat
1115
from ._compat import (
@@ -22,6 +26,8 @@
2226
# TODO import from typing (requires Python >=3.13)
2327
from typing_extensions import TypeIs
2428

29+
T = TypeVar("T")
30+
2531

2632
__all__ = [
2733
"asarrays",
@@ -31,6 +37,8 @@
3137
"is_python_scalar",
3238
"mean",
3339
"meta_namespace",
40+
"pickle_without",
41+
"unpickle_without",
3442
]
3543

3644

@@ -306,3 +314,116 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
306314
out["boolean indexing"] = True
307315
out["data-dependent shapes"] = True
308316
return out
317+
318+
319+
# Helper of ``extract_objects`` and ``repack_objects``
320+
_repacking_objects: ContextVar[Iterator[object]] = ContextVar("_repacking_objects")
321+
322+
323+
def _expand() -> object: # numpydoc ignore=RT01
324+
"""
325+
Helper of ``extract_objects`` and ``repack_objects``.
326+
327+
Inverse of the reducer function.
328+
329+
Notes
330+
-----
331+
This function must be global in order to be picklable.
332+
"""
333+
return next(_repacking_objects.get())
334+
335+
336+
def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
337+
"""
338+
Variant of ``pickle.dumps`` that extracts inner objects.
339+
340+
Conceptually, this is similar as passing the ``buffer_callback`` argument to
341+
``pickle.dumps``, but instead of extracting buffers it extracts entire objects.
342+
343+
Parameters
344+
----------
345+
obj : object
346+
The object to pickle.
347+
*classes : type
348+
One or more classes to extract from the object.
349+
The instances of these classes inside ``obj`` will not be pickled.
350+
351+
Returns
352+
-------
353+
bytes
354+
The pickled object. Must be unpickled with :func:`unpickle_without`.
355+
list
356+
All instances of ``classes`` found inside ``obj`` (not pickled).
357+
358+
See Also
359+
--------
360+
pickle.dumps : Standard pickle function.
361+
unpickle_without : Reverse function.
362+
363+
Examples
364+
--------
365+
>>> class A:
366+
... def __repr__(self):
367+
... return "<A>"
368+
... def __reduce__(self):
369+
... assert False, "Not serializable"
370+
>>> obj = {1: A(), 2: [A(), A()]} # Any serializable object
371+
>>> pik, extracted = pickle_without(obj, A)
372+
>>> extracted
373+
[<A>, <A>, <A>]
374+
>>> unpickle_without(pik, extracted)
375+
{1: <A>, 2: [<A>, <A>]}
376+
377+
This can be also used to hot-swap inner objects; the only constraint is that
378+
the number of objects in and out must be the same:
379+
380+
>>> class B:
381+
... def __repr__(self): return "<B>"
382+
>>> unpickle_without(pik, [B(), B(), B()])
383+
{1: <B>, 2: [<B>, <B>]}
384+
"""
385+
extracted = []
386+
387+
def reduce(x: T) -> tuple[Callable[[], object], tuple[()]]: # numpydoc ignore=GL08
388+
extracted.append(x)
389+
return _expand, ()
390+
391+
f = io.BytesIO()
392+
p = pickle.Pickler(f)
393+
p.dispatch_table = copyreg.dispatch_table.copy()
394+
for cls in classes:
395+
p.dispatch_table[cls] = reduce
396+
p.dump(obj)
397+
398+
return f.getvalue(), extracted
399+
400+
401+
def unpickle_without(pik: bytes, objects: Iterable[object], /) -> Any: # type: ignore[explicit-any]
402+
"""
403+
Variant of ``pickle.loads``, reverse of ``pickle_without``.
404+
405+
Parameters
406+
----------
407+
pik : bytes
408+
The pickled object generated by ``pickle_without``.
409+
objects : Iterable
410+
The objects to be reinserted into the unpickled object.
411+
Must be the at least the same number of elements as the ones extracted by
412+
``pickle_without``, but does not need to be the same objects or even the
413+
same types of objects. Excess objects, if any, won't be inserted.
414+
415+
Returns
416+
-------
417+
object
418+
The unpickled object, with the objects in ``objects`` inserted back into it.
419+
420+
See Also
421+
--------
422+
pickle_without : Serializing function.
423+
pickle.loads : Standard unpickle function.
424+
"""
425+
tok = _repacking_objects.set(iter(objects))
426+
try:
427+
return pickle.loads(pik)
428+
finally:
429+
_repacking_objects.reset(tok)

src/array_api_extra/testing.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
1414

1515
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
16+
from ._lib._utils._helpers import pickle_without, unpickle_without
1617

1718
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
1819

@@ -243,14 +244,10 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
243244

244245
for mod, name, func, tags in iter_tagged():
245246
if tags["jax_jit"]:
246-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
247-
wrapped = cast( # type: ignore[explicit-any]
248-
Callable[..., Any],
249-
jax.jit(
250-
func,
251-
static_argnums=tags["static_argnums"],
252-
static_argnames=tags["static_argnames"],
253-
),
247+
wrapped = _jax_wrap(
248+
func,
249+
static_argnums=tags["static_argnums"],
250+
static_argnames=tags["static_argnames"],
254251
)
255252
monkeypatch.setattr(mod, name, wrapped)
256253

@@ -300,6 +297,7 @@ def _dask_wrap(
300297
After the function returns, materialize the graph in order to re-raise exceptions.
301298
"""
302299
import dask
300+
import dask.array as da
303301

304302
func_name = getattr(func, "__name__", str(func))
305303
n_str = f"only up to {n}" if n else "no"
@@ -319,6 +317,40 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
319317
# Block until the graph materializes and reraise exceptions. This allows
320318
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
321319
# not work on scheduler='distributed', as it would not block.
322-
return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
320+
pik, arrays = pickle_without(out, da.Array)
321+
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
322+
unpickle_without(pik, arrays)
323323

324324
return wrapper
325+
326+
327+
def _jax_wrap(
328+
func: Callable[P, T],
329+
static_argnums: int | Sequence[int] | None,
330+
static_argnames: str | Iterable[str] | None,
331+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
332+
"""
333+
Wrap `func` inside ``jax.jit``.
334+
335+
Accepts non-array return values.
336+
"""
337+
import jax
338+
import jax.numpy as jnp
339+
340+
def inner(*args: P.args, **kwargs: P.kwargs) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
341+
out = func(*args, **kwargs)
342+
pik, arrays = pickle_without(out, jax.Array)
343+
return jnp.frombuffer(pik, dtype=jnp.uint8), *arrays
344+
345+
jitted = jax.jit(
346+
inner, static_argnums=static_argnums, static_argnames=static_argnames
347+
)
348+
cpu = jax.devices("cpu")[0]
349+
350+
@wraps(func)
351+
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
352+
arrays = jitted(*args, **kwargs)
353+
pik = bytes(arrays[0].to_device(cpu))
354+
return unpickle_without(pik, arrays[1:])
355+
356+
return outer

0 commit comments

Comments
 (0)