Skip to content

Commit e062f09

Browse files
committed
support unpickleable
1 parent 02c579d commit e062f09

File tree

3 files changed

+171
-95
lines changed

3 files changed

+171
-95
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 131 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22

33
from __future__ import annotations
44

5-
import copyreg
65
import io
76
import math
87
import pickle
9-
from collections.abc import Callable, Generator, Iterable, Iterator
10-
from contextvars import ContextVar
8+
from collections.abc import Generator, Iterable
119
from types import ModuleType
12-
from typing import TYPE_CHECKING, Any, TypeVar, cast
10+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
1311

1412
from . import _compat
1513
from ._compat import (
@@ -23,8 +21,13 @@
2321
from ._typing import Array
2422

2523
if TYPE_CHECKING: # pragma: no cover
26-
# TODO import from typing (requires Python >=3.13)
27-
from typing_extensions import TypeIs
24+
# TODO import from typing (requires Python >=3.12 and >=3.13)
25+
from typing_extensions import TypeIs, override
26+
else:
27+
28+
def override(func):
29+
return func
30+
2831

2932
T = TypeVar("T")
3033

@@ -316,48 +319,33 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
316319
return out
317320

318321

319-
# Helper of ``extract_objects`` and ``repack_objects``
320-
_repacking_objects: ContextVar[Iterator[object]] = ContextVar("_repacking_objects")
321-
322-
323-
def _expand() -> object: # numpydoc ignore=RT01
324-
"""
325-
Helper of ``extract_objects`` and ``repack_objects``.
326-
327-
Inverse of the reducer function.
328-
329-
Notes
330-
-----
331-
This function must be global in order to be picklable.
332-
"""
333-
try:
334-
return next(_repacking_objects.get())
335-
except StopIteration:
336-
msg = "Not enough objects to repack"
337-
raise ValueError(msg)
338-
339-
340-
def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
322+
def pickle_without(
323+
obj: object, cls: type[T] | tuple[type[T], ...] = ()
324+
) -> tuple[bytes, tuple[T, ...], tuple[object, ...]]:
341325
"""
342-
Variant of ``pickle.dumps`` that extracts inner objects.
326+
Variant of ``pickle.dumps`` that always succeeds and extracts inner objects.
343327
344328
Conceptually, this is similar to passing the ``buffer_callback`` argument to
345-
``pickle.dumps``, but instead of extracting buffers it extracts entire objects.
329+
``pickle.dumps``, but instead of extracting buffers it extracts entire objects,
330+
which are either not serializable with ``pickle`` (e.g. local classes or functions)
331+
or instances of an explicit list of classes.
346332
347333
Parameters
348334
----------
349335
obj : object
350336
The object to pickle.
351-
*classes : type
352-
One or more classes to extract from the object.
337+
cls : type | tuple[type, ...], optional
338+
One or multiple classes to extract from the object.
353339
The instances of these classes inside ``obj`` will not be pickled.
354340
355341
Returns
356342
-------
357343
bytes
358344
The pickled object. Must be unpickled with :func:`unpickle_without`.
359-
list
360-
All instances of ``classes`` found inside ``obj`` (not pickled).
345+
tuple
346+
All instances of ``cls`` found inside ``obj`` (not pickled).
347+
tuple
348+
All other objects which failed to pickle.
361349
362350
See Also
363351
--------
@@ -366,75 +354,144 @@ def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
366354
367355
Examples
368356
--------
357+
>>> class NS:
358+
... def __repr__(self):
359+
... return "<NS>"
360+
... def __reduce__(self):
361+
... assert False
369362
>>> class A:
370363
... def __repr__(self):
371364
... return "<A>"
372-
... def __reduce__(self):
373-
... assert False, "Not serializable"
374-
>>> obj = {1: A(), 2: [A(), A()]} # Any serializable object
375-
>>> pik, extracted = pickle_without(obj, A)
376-
>>> extracted
377-
[<A>, <A>, <A>]
378-
>>> unpickle_without(pik, extracted)
379-
{1: <A>, 2: [<A>, <A>]}
365+
>>> obj = {1: A(), 2: [A(), NS(), A()]} # Any serializable object
366+
>>> pik, instances, unpickleable = pickle_without(obj, A)
367+
>>> instances, unpickleable
368+
([<A>, <A>, <A>], [<NS>])
369+
>>> unpickle_without(pik, instances, unpickleable)
370+
{1: <A>, 2: [<A>, <NS>, <A>]}
380371
381372
This can be also used to hot-swap inner objects; the only constraint is that
382373
the number of objects in and out must be the same:
383374
384375
>>> class B:
385376
... def __repr__(self): return "<B>"
386-
>>> unpickle_without(pik, [B(), B(), B()])
387-
{1: <B>, 2: [<B>, <B>]}
377+
>>> unpickle_without(pik, [B(), B(), B()], [NS()])
378+
{1: <B>, 2: [<B>, <NS>, <B>]}
388379
"""
389-
extracted = []
390-
391-
def reduce(x: T) -> tuple[Callable[[], object], tuple[()]]: # numpydoc ignore=GL08
392-
extracted.append(x)
393-
return _expand, ()
380+
instances: list[T] = []
381+
unpickleable: list[object] = []
382+
seen: dict[int, Literal[0, 1, None]] = {}
383+
384+
class Pickler(pickle.Pickler): # numpydoc ignore=GL01,RT01
385+
"""Override pickle.Pickler.persistent_id.
386+
387+
TODO consider moving to top-level scope to allow for
388+
the full Pickler API to be used.
389+
"""
390+
391+
@override
392+
def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
393+
id_ = id(obj)
394+
try:
395+
return seen[id_]
396+
except KeyError:
397+
pass
398+
399+
if isinstance(obj, cls):
400+
instances.append(obj) # type: ignore[arg-type]
401+
seen[id_] = 0
402+
return id_, 0
403+
404+
try:
405+
_ = obj.__reduce__()
406+
except Exception: # pylint: disable=broad-exception-caught
407+
pass
408+
else: # Can be pickled
409+
seen[id_] = None
410+
return None
411+
412+
# May be a global function, which can be pickled
413+
try:
414+
_ = pickle.dumps(obj)
415+
except Exception: # pylint: disable=broad-exception-caught
416+
pass
417+
else: # Can be pickled
418+
seen[id_] = None
419+
return None
420+
421+
# Can't be pickled
422+
unpickleable.append(obj)
423+
seen[id_] = 1
424+
return id_, 1
394425

395426
f = io.BytesIO()
396-
p = pickle.Pickler(f)
397-
398-
# Override the reducer for the given classes and all their
399-
# subclasses (recursively).
400-
p.dispatch_table = copyreg.dispatch_table.copy()
401-
subclasses = list(classes)
402-
while subclasses:
403-
cls = subclasses.pop()
404-
p.dispatch_table[cls] = reduce
405-
subclasses.extend(cls.__subclasses__())
406-
427+
p = Pickler(f)
407428
p.dump(obj)
429+
return f.getvalue(), tuple(instances), tuple(unpickleable)
408430

409-
return f.getvalue(), extracted
410431

411-
412-
def unpickle_without(pik: bytes, objects: Iterable[object], /) -> Any: # type: ignore[explicit-any]
432+
def unpickle_without( # type: ignore[explicit-any]
433+
pik: bytes,
434+
instances: Iterable[object],
435+
unpickleable: Iterable[object],
436+
/,
437+
) -> Any:
413438
"""
414439
Variant of ``pickle.loads``, reverse of ``pickle_without``.
415440
416441
Parameters
417442
----------
418443
pik : bytes
419444
The pickled object generated by ``pickle_without``.
420-
objects : Iterable
421-
The objects to be reinserted into the unpickled object.
422-
Must be the at least the same number of elements as the ones extracted by
423-
``pickle_without``, but does not need to be the same objects or even the
424-
same types of objects. Excess objects, if any, won't be inserted.
445+
instances : Iterable[object]
446+
Instances of the class or classes explicitly passed to ``pickle_without``,
447+
to be reinserted into the unpickled object.
448+
unpickleable : Iterable[object]
449+
The objects that failed to pickle, as returned by ``pickle_without``.
425450
426451
Returns
427452
-------
428453
object
429-
The unpickled object, with the objects in ``objects`` inserted back into it.
454+
The unpickled object.
430455
431456
See Also
432457
--------
433458
pickle_without : Serializing function.
434459
pickle.loads : Standard unpickle function.
460+
461+
Notes
462+
-----
463+
The second and third parameter of this function must yield at least the same number
464+
of elements as the ones returned by ``pickle_without``, but do not need to be the
465+
same objects, or even the same types of objects. Excess objects, if any, will be
466+
quietly ignored.
435467
"""
436-
tok = _repacking_objects.set(iter(objects))
437-
try:
438-
return pickle.loads(pik)
439-
finally:
440-
_repacking_objects.reset(tok)
468+
iters = iter(instances), iter(unpickleable)
469+
seen: dict[int, object] = {}
470+
471+
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL01,RT01
472+
"""
473+
Override pickle.Pickler.persistent_load.
474+
475+
TODO consider moving to top-level scope to allow for
476+
the full Unpickler API to be used.
477+
"""
478+
479+
@override
480+
def persistent_load(self, pid: tuple[int, int]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
481+
prev_id, kind = pid
482+
try:
483+
return seen[prev_id]
484+
except KeyError:
485+
pass
486+
487+
try:
488+
obj = next(iters[kind])
489+
except StopIteration as e:
490+
msg = "Not enough objects to unpickle"
491+
raise ValueError(msg) from e
492+
493+
seen[prev_id] = obj
494+
return obj
495+
496+
f = io.BytesIO(pik)
497+
return Unpickler(f).load()

src/array_api_extra/testing.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
329329
# Block until the graph materializes and reraise exceptions. This allows
330330
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
331331
# not work on scheduler='distributed', as it would not block.
332-
pik, arrays = pickle_without(out, da.Array)
332+
pik, arrays, unpickleable = pickle_without(out, da.Array)
333333
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
334-
return unpickle_without(pik, arrays)
334+
return unpickle_without(pik, arrays, unpickleable) # pyright: ignore[reportUnknownArgumentType]
335335

336336
return wrapper
337337

@@ -343,6 +343,8 @@ def _jax_autojit(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR0
343343
- Array-like arguments and return values are not automatically converted to
344344
``jax.Array`` objects.
345345
- All non-array arguments are automatically treated as static.
346+
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
347+
``pickle``.
346348
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
347349
tuple/list/dict, but can be any object serializable with ``pickle``.
348350
- Automatically descend into non-array arguments and find ``jax.Array`` objects
@@ -354,27 +356,44 @@ def _jax_autojit(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR0
354356
import jax
355357

356358
# pickled return values of `func`, minus the JAX arrays
357-
res_piks = {}
358-
359-
def jit_cache_key(args_pik: bytes, *arg_arrays: jax.Array) -> Hashable: # type: ignore[no-any-unimported]
360-
return (args_pik, *((arr.shape, arr.dtype) for arr in arg_arrays))
361-
362-
def inner(args_pik: bytes, *arg_arrays: jax.Array) -> list[jax.Array]: # type: ignore[no-any-unimported] # numpydoc ignore=GL08
363-
args, kwargs = unpickle_without(args_pik, arg_arrays)
364-
res = func(*args, **kwargs)
365-
res_pik, res_arrays = pickle_without(res, jax.Array)
366-
key = jit_cache_key(args_pik, *arg_arrays)
367-
prev = res_piks.setdefault(key, res_pik)
368-
assert prev == res_pik, "cache key collision"
359+
static_return_values: dict[Hashable, tuple[bytes, tuple[object, ...]]] = {}
360+
361+
def jit_cache_key( # type: ignore[no-any-unimported]
362+
args_pik: bytes,
363+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
364+
args_unpickleable: tuple[Hashable, ...],
365+
) -> Hashable:
366+
return (
367+
args_pik,
368+
tuple((arr.shape, arr.dtype) for arr in args_arrays), # pyright: ignore[reportUnknownArgumentType]
369+
args_unpickleable,
370+
)
371+
372+
def inner( # type: ignore[no-any-unimported] # pyright: ignore[reportUnknownParameterType]
373+
args_pik: bytes,
374+
args_arrays: tuple[jax.Array, ...], # pyright: ignore[reportUnknownParameterType]
375+
args_unpickleable: tuple[Hashable, ...],
376+
) -> tuple[jax.Array, ...]: # numpydoc ignore=GL08
377+
args, kwargs = unpickle_without(args_pik, args_arrays, args_unpickleable) # pyright: ignore[reportUnknownArgumentType]
378+
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
379+
res_pik, res_arrays, res_unpickleable = pickle_without(res, jax.Array) # pyright: ignore[reportUnknownArgumentType]
380+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
381+
val = res_pik, res_unpickleable
382+
prev = static_return_values.setdefault(key, val)
383+
assert prev == val, "cache key collision"
369384
return res_arrays
370385

371-
jitted = jax.jit(inner, static_argnums=0)
386+
jitted = jax.jit(inner, static_argnums=(0, 2))
372387

373388
@wraps(func)
374389
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
375-
args_pik, arg_arrays = pickle_without((args, kwargs), jax.Array)
376-
res_arrays = jitted(args_pik, *arg_arrays)
377-
res_pik = res_piks[jit_cache_key(args_pik, *arg_arrays)]
378-
return unpickle_without(res_pik, res_arrays)
390+
args_pik, args_arrays, args_unpickleable = pickle_without(
391+
(args, kwargs),
392+
jax.Array, # pyright: ignore[reportUnknownArgumentType]
393+
)
394+
res_arrays = jitted(args_pik, args_arrays, args_unpickleable)
395+
key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
396+
res_pik, res_unpickleable = static_return_values[key]
397+
return unpickle_without(res_pik, res_arrays, res_unpickleable) # pyright: ignore[reportUnknownArgumentType]
379398

380399
return outer

tests/test_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ def test_lazy_xp_function_static_params(xp: ModuleType):
194194

195195
def test_lazy_xp_function_deprecated_static_argnames():
196196
with pytest.warns(DeprecationWarning, match="static_argnames"):
197-
lazy_xp_function(static_params, static_argnames=["flag"]) # type: ignore[arg-type]
197+
lazy_xp_function(static_params, static_argnames=["flag"]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198198
with pytest.warns(DeprecationWarning, match="static_argnums"):
199-
lazy_xp_function(static_params, static_argnums=[1]) # type: ignore[arg-type]
199+
lazy_xp_function(static_params, static_argnums=[1]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
200200

201201

202202
try:

0 commit comments

Comments
 (0)