22
33from __future__ import annotations
44
5- import copyreg
65import io
76import math
87import pickle
9- from collections .abc import Callable , Generator , Iterable , Iterator
10- from contextvars import ContextVar
8+ from collections .abc import Generator , Iterable
119from types import ModuleType
12- from typing import TYPE_CHECKING , Any , TypeVar , cast
10+ from typing import TYPE_CHECKING , Any , Literal , TypeVar , cast
1311
1412from . import _compat
1513from ._compat import (
2321from ._typing import Array
2422
2523if TYPE_CHECKING : # pragma: no cover
26- # TODO import from typing (requires Python >=3.13)
27- from typing_extensions import TypeIs
24+ # TODO import from typing (requires Python >=3.12 and >=3.13)
25+ from typing_extensions import TypeIs , override
26+ else :
27+
28+ def override (func ):
29+ return func
30+
2831
2932T = TypeVar ("T" )
3033
@@ -316,48 +319,33 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
316319 return out
317320
318321
319- # Helper of ``extract_objects`` and ``repack_objects``
320- _repacking_objects : ContextVar [Iterator [object ]] = ContextVar ("_repacking_objects" )
321-
322-
323- def _expand () -> object : # numpydoc ignore=RT01
324- """
325- Helper of ``extract_objects`` and ``repack_objects``.
326-
327- Inverse of the reducer function.
328-
329- Notes
330- -----
331- This function must be global in order to be picklable.
332- """
333- try :
334- return next (_repacking_objects .get ())
335- except StopIteration :
336- msg = "Not enough objects to repack"
337- raise ValueError (msg )
338-
339-
340- def pickle_without (obj : object , * classes : type [T ]) -> tuple [bytes , list [T ]]:
322+ def pickle_without (
323+ obj : object , cls : type [T ] | tuple [type [T ], ...] = ()
324+ ) -> tuple [bytes , tuple [T , ...], tuple [object , ...]]:
341325 """
342- Variant of ``pickle.dumps`` that extracts inner objects.
326+ Variant of ``pickle.dumps`` that always succeeds and extracts inner objects.
343327
344328 Conceptually, this is similar to passing the ``buffer_callback`` argument to
345- ``pickle.dumps``, but instead of extracting buffers it extracts entire objects.
329+ ``pickle.dumps``, but instead of extracting buffers it extracts entire objects,
330+ which are either not serializable with ``pickle`` (e.g. local classes or functions)
331+ or instances of an explicit list of classes.
346332
347333 Parameters
348334 ----------
349335 obj : object
350336 The object to pickle.
351- *classes : type
352- One or more classes to extract from the object.
337+ cls : type | tuple[type, ...], optional
338+ One or multiple classes to extract from the object.
353339 The instances of these classes inside ``obj`` will not be pickled.
354340
355341 Returns
356342 -------
357343 bytes
358344 The pickled object. Must be unpickled with :func:`unpickle_without`.
359- list
360- All instances of ``classes`` found inside ``obj`` (not pickled).
345+ tuple
346+ All instances of ``cls`` found inside ``obj`` (not pickled).
347+ tuple
348+ All other objects which failed to pickle.
361349
362350 See Also
363351 --------
@@ -366,75 +354,144 @@ def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
366354
367355 Examples
368356 --------
357+ >>> class NS:
358+ ... def __repr__(self):
359+ ... return "<NS>"
360+ ... def __reduce__(self):
361+ ... assert False
369362 >>> class A:
370363 ... def __repr__(self):
371364 ... return "<A>"
372- ... def __reduce__(self):
373- ... assert False, "Not serializable"
374- >>> obj = {1: A(), 2: [A(), A()]} # Any serializable object
375- >>> pik, extracted = pickle_without(obj, A)
376- >>> extracted
377- [<A>, <A>, <A>]
378- >>> unpickle_without(pik, extracted)
379- {1: <A>, 2: [<A>, <A>]}
365+ >>> obj = {1: A(), 2: [A(), NS(), A()]} # Any serializable object
366+ >>> pik, instances, unpickleable = pickle_without(obj, A)
367+ >>> instances, unpickleable
368+ ([<A>, <A>, <A>], [<NS>])
369+ >>> unpickle_without(pik, instances, unpickleable)
370+ {1: <A>, 2: [<A>, <NS>, <A>]}
380371
381372 This can be also used to hot-swap inner objects; the only constraint is that
382373 the number of objects in and out must be the same:
383374
384375 >>> class B:
385376 ... def __repr__(self): return "<B>"
386- >>> unpickle_without(pik, [B(), B(), B()])
387- {1: <B>, 2: [<B>, <B>]}
377+ >>> unpickle_without(pik, [B(), B(), B()], [NS()] )
378+ {1: <B>, 2: [<B>, <NS>, < B>]}
388379 """
389- extracted = []
390-
391- def reduce (x : T ) -> tuple [Callable [[], object ], tuple [()]]: # numpydoc ignore=GL08
392- extracted .append (x )
393- return _expand , ()
380+ instances : list [T ] = []
381+ unpickleable : list [object ] = []
382+ seen : dict [int , Literal [0 , 1 , None ]] = {}
383+
384+ class Pickler (pickle .Pickler ): # numpydoc ignore=GL01,RT01
385+ """Override pickle.Pickler.persistent_id.
386+
387+ TODO consider moving to top-level scope to allow for
388+ the full Pickler API to be used.
389+ """
390+
391+ @override
392+ def persistent_id (self , obj : object ) -> object : # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
393+ id_ = id (obj )
394+ try :
395+ return seen [id_ ]
396+ except KeyError :
397+ pass
398+
399+ if isinstance (obj , cls ):
400+ instances .append (obj ) # type: ignore[arg-type]
401+ seen [id_ ] = 0
402+ return id_ , 0
403+
404+ try :
405+ _ = obj .__reduce__ ()
406+ except Exception : # pylint: disable=broad-exception-caught
407+ pass
408+ else : # Can be pickled
409+ seen [id_ ] = None
410+ return None
411+
412+ # May be a global function, which can be pickled
413+ try :
414+ _ = pickle .dumps (obj )
415+ except Exception : # pylint: disable=broad-exception-caught
416+ pass
417+ else : # Can be pickled
418+ seen [id_ ] = None
419+ return None
420+
421+ # Can't be pickled
422+ unpickleable .append (obj )
423+ seen [id_ ] = 1
424+ return id_ , 1
394425
395426 f = io .BytesIO ()
396- p = pickle .Pickler (f )
397-
398- # Override the reducer for the given classes and all their
399- # subclasses (recursively).
400- p .dispatch_table = copyreg .dispatch_table .copy ()
401- subclasses = list (classes )
402- while subclasses :
403- cls = subclasses .pop ()
404- p .dispatch_table [cls ] = reduce
405- subclasses .extend (cls .__subclasses__ ())
406-
427+ p = Pickler (f )
407428 p .dump (obj )
429+ return f .getvalue (), tuple (instances ), tuple (unpickleable )
408430
409- return f .getvalue (), extracted
410431
411-
412- def unpickle_without (pik : bytes , objects : Iterable [object ], / ) -> Any : # type: ignore[explicit-any]
432+ def unpickle_without ( # type: ignore[explicit-any]
433+ pik : bytes ,
434+ instances : Iterable [object ],
435+ unpickleable : Iterable [object ],
436+ / ,
437+ ) -> Any :
413438 """
414439 Variant of ``pickle.loads``, reverse of ``pickle_without``.
415440
416441 Parameters
417442 ----------
418443 pik : bytes
419444 The pickled object generated by ``pickle_without``.
420- objects : Iterable
421- The objects to be reinserted into the unpickled object.
422- Must be the at least the same number of elements as the ones extracted by
423- ``pickle_without``, but does not need to be the same objects or even the
424- same types of objects. Excess objects, if any, won't be inserted .
445+ instances : Iterable[object]
446+ Instances of the class or classes explicitly passed to ``pickle_without``,
447+ to be reinserted into the unpickled object.
448+ unpickleable : Iterable[object]
449+ The objects that failed to pickle, as returned by ``pickle_without`` .
425450
426451 Returns
427452 -------
428453 object
429- The unpickled object, with the objects in ``objects`` inserted back into it .
454+ The unpickled object.
430455
431456 See Also
432457 --------
433458 pickle_without : Serializing function.
434459 pickle.loads : Standard unpickle function.
460+
461+ Notes
462+ -----
463+ The second and third parameter of this function must yield at least the same number
464+ of elements as the ones returned by ``pickle_without``, but do not need to be the
465+ same objects, or even the same types of objects. Excess objects, if any, will be
466+ quietly ignored.
435467 """
436- tok = _repacking_objects .set (iter (objects ))
437- try :
438- return pickle .loads (pik )
439- finally :
440- _repacking_objects .reset (tok )
468+ iters = iter (instances ), iter (unpickleable )
469+ seen : dict [int , object ] = {}
470+
471+ class Unpickler (pickle .Unpickler ): # numpydoc ignore=GL01,RT01
472+ """
473+ Override pickle.Pickler.persistent_load.
474+
475+ TODO consider moving to top-level scope to allow for
476+ the full Unpickler API to be used.
477+ """
478+
479+ @override
480+ def persistent_load (self , pid : tuple [int , int ]) -> object : # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
481+ prev_id , kind = pid
482+ try :
483+ return seen [prev_id ]
484+ except KeyError :
485+ pass
486+
487+ try :
488+ obj = next (iters [kind ])
489+ except StopIteration as e :
490+ msg = "Not enough objects to unpickle"
491+ raise ValueError (msg ) from e
492+
493+ seen [prev_id ] = obj
494+ return obj
495+
496+ f = io .BytesIO (pik )
497+ return Unpickler (f ).load ()
0 commit comments