22
33from  __future__ import  annotations 
44
5- import  copyreg 
65import  io 
76import  math 
87import  pickle 
9- from  collections .abc  import  Callable , Generator , Iterable ,  Iterator 
10- from  contextvars  import  ContextVar 
11- from  types  import  ModuleType 
12- from  typing  import  TYPE_CHECKING , Any , TypeVar , cast 
8+ from  collections .abc  import  Callable , Generator , Hashable ,  Iterable 
9+ from  functools  import  wraps 
10+ from  types  import  ModuleType ,  NoneType 
11+ from  typing  import  TYPE_CHECKING , Any , Literal ,  ParamSpec ,  TypeVar , cast 
1312
1413from  . import  _compat 
1514from  ._compat  import  (
2322from  ._typing  import  Array 
2423
2524if  TYPE_CHECKING :  # pragma: no cover 
26-     # TODO import from typing (requires Python >=3.13) 
27-     from  typing_extensions  import  TypeIs 
25+     # TODO import from typing (override requires Python >=3.12, TypeIs >=3.13)
26+     from  typing_extensions  import  TypeIs , override 
27+ else :
2828
29+     def  override (func ):
30+         return  func 
31+ 
32+ 
33+ P  =  ParamSpec ("P" )
2934T  =  TypeVar ("T" )
3035
3136
3540    "eager_shape" ,
3641    "in1d" ,
3742    "is_python_scalar" ,
43+     "jax_autojit" ,
3844    "mean" ,
3945    "meta_namespace" ,
4046    "pickle_without" ,
@@ -316,48 +322,39 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
316322    return  out 
317323
318324
319- # Helper of ``extract_objects`` and ``repack_objects`` 
320- _repacking_objects : ContextVar [Iterator [object ]] =  ContextVar ("_repacking_objects" )
321- 
322- 
323- def  _expand () ->  object :  # numpydoc ignore=RT01 
324-     """ 
325-     Helper of ``extract_objects`` and ``repack_objects``. 
326- 
327-     Inverse of the reducer function. 
328- 
329-     Notes 
330-     ----- 
331-     This function must be global in order to be picklable. 
332-     """ 
333-     try :
334-         return  next (_repacking_objects .get ())
335-     except  StopIteration :
336-         msg  =  "Not enough objects to repack" 
337-         raise  ValueError (msg )
325+ _BASIC_TYPES  =  frozenset ((
326+     NoneType , bool , int , float , complex , str , bytes , bytearray ,
327+     list , tuple , dict , set , frozenset , range , slice ,
328+ ))  # fmt: skip 
338329
339330
340- def  pickle_without (obj : object , * classes : type [T ]) ->  tuple [bytes , list [T ]]:
331+ def  pickle_without (
332+     obj : object , cls : type [T ] |  tuple [type [T ], ...] =  ()
333+ ) ->  tuple [bytes , tuple [T , ...], tuple [object , ...]]:
341334    """ 
342-     Variant of ``pickle.dumps`` that extracts inner objects. 
335+     Variant of ``pickle.dumps`` that always succeeds and  extracts inner objects. 
343336
344337    Conceptually, this is similar to passing the ``buffer_callback`` argument to 
345-     ``pickle.dumps``, but instead of extracting buffers it extracts entire objects. 
338+     ``pickle.dumps``, but instead of extracting buffers it extracts entire objects, 
339+     which are either not serializable with ``pickle`` (e.g. local classes or functions) 
340+     or instances of an explicit list of classes. 
346341
347342    Parameters 
348343    ---------- 
349344    obj : object 
350345        The object to pickle. 
351-     *classes  : type 
352-         One or more  classes to extract from the object. 
346+     cls  : type | tuple[type, ...], optional  
347+         One or more classes to extract from the object.
353348        The instances of these classes inside ``obj`` will not be pickled. 
354349
355350    Returns 
356351    ------- 
357352    bytes 
358353        The pickled object. Must be unpickled with :func:`unpickle_without`. 
359-     list 
360-         All instances of ``classes`` found inside ``obj`` (not pickled). 
354+     tuple 
355+         All instances of ``cls`` found inside ``obj`` (not pickled). 
356+     tuple 
357+         All other objects which failed to pickle. 
361358
362359    See Also 
363360    -------- 
@@ -366,75 +363,221 @@ def pickle_without(obj: object, *classes: type[T]) -> tuple[bytes, list[T]]:
366363
367364    Examples 
368365    -------- 
366+     >>> class NS: 
367+     ...     def __repr__(self): 
368+     ...         return "<NS>" 
369+     ...     def __reduce__(self): 
370+     ...         assert False 
369371    >>> class A: 
370372    ...     def __repr__(self): 
371373    ...         return "<A>" 
372-     ...     def __reduce__(self): 
373-     ...         assert False, "Not serializable" 
374-     >>> obj = {1: A(), 2: [A(), A()]}  # Any serializable object 
375-     >>> pik, extracted = pickle_without(obj, A) 
376-     >>> extracted 
377-     [<A>, <A>, <A>] 
378-     >>> unpickle_without(pik, extracted) 
379-     {1: <A>, 2: [<A>, <A>]} 
374+     >>> obj = {1: A(), 2: [A(), NS(), A()]} 
375+     >>> pik, instances, unpickleable = pickle_without(obj, A) 
376+     >>> instances, unpickleable 
377+     ((<A>, <A>, <A>), (<NS>,))
378+     >>> unpickle_without(pik, instances, unpickleable) 
379+     {1: <A>, 2: [<A>, <NS>, <A>]} 
380380
381381    This can also be used to hot-swap inner objects; the only constraint is that
382382    the number of objects in and out must be the same: 
383383
384384    >>> class B: 
385385    ...     def __repr__(self): return "<B>" 
386-     >>> unpickle_without(pik, [B(), B(), B()]) 
387-     {1: <B>, 2: [<B>, <B>]} 
386+     >>> unpickle_without(pik, [B(), B(), B()], [NS()] ) 
387+     {1: <B>, 2: [<B>, <NS>, <B>]}
388388    """ 
389-     extracted  =  []
390- 
391-     def  reduce (x : T ) ->  tuple [Callable [[], object ], tuple [()]]:  # numpydoc ignore=GL08 
392-         extracted .append (x )
393-         return  _expand , ()
389+     instances : list [T ] =  []
390+     unpickleable : list [object ] =  []
391+     seen : dict [int , Literal [0 , 1 , None ]] =  {}
392+ 
393+     class  Pickler (pickle .Pickler ):  # numpydoc ignore=GL01,RT01 
394+         """Override pickle.Pickler.persistent_id. 
395+ 
396+         TODO consider moving to top-level scope to allow for 
397+         the full Pickler API to be used. 
398+         """ 
399+ 
400+         @override  
401+         def  persistent_id (self , obj : object ) ->  object :  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08 
402+             # Fast exit in case of basic builtin types. 
403+             # Note that basic collections (tuple, list, etc.) are in this; 
404+             # persistent_id() will be called again with their contents. 
405+             if  type (obj ) in  _BASIC_TYPES :  # No subclasses! 
406+                 return  None 
407+ 
408+             id_  =  id (obj )
409+             try :
410+                 kind  =  seen [id_ ]
411+                 return  None  if  kind  is  None  else  (id_ , kind )
412+             except  KeyError :
413+                 pass 
414+ 
415+             if  isinstance (obj , cls ):
416+                 instances .append (obj )  # type: ignore[arg-type] 
417+                 seen [id_ ] =  0 
418+                 return  id_ , 0 
419+ 
420+             for  func  in  (
421+                 # Note: a class that defines __slots__ without defining __getstate__ 
422+                 # cannot be pickled with __reduce__(), but can with __reduce_ex__(5) 
423+                 lambda : obj .__reduce_ex__ (pickle .HIGHEST_PROTOCOL ),
424+                 lambda : obj .__reduce__ (),
425+                 # Global functions don't support __reduce__, but can still be pickled
426+                 lambda : pickle .dumps (obj , protocol = pickle .HIGHEST_PROTOCOL ),
427+             ):
428+                 try :
429+                     # a class that defines __slots__ without defining __getstate__ 
430+                     # cannot be pickled with __reduce__(), but can with __reduce_ex__(5) 
431+                     func ()
432+                 except  Exception :  # pylint: disable=broad-exception-caught 
433+                     pass 
434+                 else :  # Can be pickled 
435+                     seen [id_ ] =  None 
436+                     return  None 
437+ 
438+             # Can't be pickled 
439+             unpickleable .append (obj )
440+             seen [id_ ] =  1 
441+             return  id_ , 1 
394442
395443    f  =  io .BytesIO ()
396-     p  =  pickle .Pickler (f )
397- 
398-     # Override the reducer for the given classes and all their 
399-     # subclasses (recursively). 
400-     p .dispatch_table  =  copyreg .dispatch_table .copy ()
401-     subclasses  =  list (classes )
402-     while  subclasses :
403-         cls  =  subclasses .pop ()
404-         p .dispatch_table [cls ] =  reduce 
405-         subclasses .extend (cls .__subclasses__ ())
406- 
444+     p  =  Pickler (f , protocol = pickle .HIGHEST_PROTOCOL )
407445    p .dump (obj )
446+     return  f .getvalue (), tuple (instances ), tuple (unpickleable )
408447
409-     return  f .getvalue (), extracted 
410448
411- 
412- def  unpickle_without (pik : bytes , objects : Iterable [object ], / ) ->  Any :  # type: ignore[explicit-any] 
449+ def  unpickle_without (  # type: ignore[explicit-any] 
450+     pik : bytes ,
451+     instances : Iterable [object ],
452+     unpickleable : Iterable [object ],
453+     / ,
454+ ) ->  Any :
413455    """ 
414456    Variant of ``pickle.loads``, reverse of ``pickle_without``. 
415457
416458    Parameters 
417459    ---------- 
418460    pik : bytes 
419461        The pickled object generated by ``pickle_without``. 
420-     objects  : Iterable 
421-         The objects to be reinserted into the unpickled object.  
422-         Must  be the at least  the same number of elements as the ones extracted by  
423-         ``pickle_without``, but does not need to be the same objects or even the  
424-         same types of  objects. Excess objects, if any, won't be inserted . 
462+     instances  : Iterable[object]  
463+         Instances of the class or classes explicitly passed to ``pickle_without``,  
464+         to  be reinserted into  the unpickled object.  
465+     unpickleable : Iterable[object]  
466+         The  objects that failed to pickle, as returned by ``pickle_without`` . 
425467
426468    Returns 
427469    ------- 
428470    object 
429-         The unpickled object, with the objects in ``objects`` inserted back into it . 
471+         The unpickled object. 
430472
431473    See Also 
432474    -------- 
433475    pickle_without : Serializing function. 
434476    pickle.loads : Standard unpickle function. 
477+ 
478+     Notes 
479+     ----- 
480+     The second and third parameter of this function must yield at least the same number 
481+     of elements as the ones returned by ``pickle_without``, but do not need to be the 
482+     same objects, or even the same types of objects. Excess objects, if any, will be 
483+     quietly ignored. 
484+     """ 
485+     iters  =  iter (instances ), iter (unpickleable )
486+     seen : dict [tuple [int , int ], object ] =  {}
487+ 
488+     class  Unpickler (pickle .Unpickler ):  # numpydoc ignore=GL01,RT01 
489+         """ 
490+         Override pickle.Unpickler.persistent_load.
491+ 
492+         TODO consider moving to top-level scope to allow for 
493+         the full Unpickler API to be used. 
494+         """ 
495+ 
496+         @override  
497+         def  persistent_load (self , pid : tuple [int , int ]) ->  object :  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08 
498+             try :
499+                 return  seen [pid ]
500+             except  KeyError :
501+                 pass 
502+ 
503+             _ , kind  =  pid 
504+             try :
505+                 obj  =  next (iters [kind ])
506+             except  StopIteration  as  e :
507+                 msg  =  "Not enough objects to unpickle" 
508+                 raise  ValueError (msg ) from  e 
509+ 
510+             seen [pid ] =  obj 
511+             return  obj 
512+ 
513+     f  =  io .BytesIO (pik )
514+     return  Unpickler (f ).load ()
515+ 
516+ 
517+ def  jax_autojit (
518+     func : Callable [P , T ],
519+ ) ->  Callable [P , T ]:  # numpydoc ignore=PR01,RT01,SS03 
520+     """ 
521+     Wrap `func` with ``jax.jit``, with the following differences: 
522+ 
523+     - Array-like arguments and return values are not automatically converted to 
524+       ``jax.Array`` objects. 
525+     - All non-array arguments are automatically treated as static. 
526+       Unlike ``jax.jit``, static arguments must be either hashable or serializable with 
527+       ``pickle``. 
528+     - Unlike ``jax.jit``, non-array arguments and return values are not limited to 
529+       tuple/list/dict, but can be any object serializable with ``pickle``. 
530+     - Automatically descend into non-array arguments and find ``jax.Array`` objects 
531+       inside them, then rebuild the arguments when entering `func`, swapping the JAX 
532+       concrete arrays with tracer objects. 
533+     - Automatically descend into non-array return values and find ``jax.Array`` objects 
534+       inside them, then rebuild them downstream of exiting the JIT, swapping the JAX 
535+       tracer objects with concrete arrays. 
435536    """ 
436-     tok  =  _repacking_objects .set (iter (objects ))
437-     try :
438-         return  pickle .loads (pik )
439-     finally :
440-         _repacking_objects .reset (tok )
537+     import  jax 
538+ 
539+     # { 
540+     #   jit_cache_key(args_pik, args_arrays, args_unpickleable): 
541+     #   (res_pik, res_unpickleable) 
542+     # } 
543+     static_return_values : dict [Hashable , tuple [bytes , tuple [object , ...]]] =  {}
544+ 
545+     def  jit_cache_key (  # type: ignore[no-any-unimported]  # numpydoc ignore=GL08 
546+         args_pik : bytes ,
547+         args_arrays : tuple [jax .Array , ...],  # pyright: ignore[reportUnknownParameterType] 
548+         args_unpickleable : tuple [Hashable , ...],
549+     ) ->  Hashable :
550+         return  (
551+             args_pik ,
552+             tuple ((arr .shape , arr .dtype ) for  arr  in  args_arrays ),  # pyright: ignore[reportUnknownArgumentType] 
553+             args_unpickleable ,
554+         )
555+ 
556+     def  inner (  # type: ignore[no-any-unimported]  # pyright: ignore[reportUnknownParameterType] 
557+         args_pik : bytes ,
558+         args_arrays : tuple [jax .Array , ...],  # pyright: ignore[reportUnknownParameterType] 
559+         args_unpickleable : tuple [Hashable , ...],
560+     ) ->  tuple [jax .Array , ...]:  # numpydoc ignore=GL08 
561+         args , kwargs  =  unpickle_without (args_pik , args_arrays , args_unpickleable )  # pyright: ignore[reportUnknownArgumentType] 
562+         res  =  func (* args , ** kwargs )  # pyright: ignore[reportCallIssue] 
563+         res_pik , res_arrays , res_unpickleable  =  pickle_without (res , jax .Array )  # pyright: ignore[reportUnknownArgumentType] 
564+         key  =  jit_cache_key (args_pik , args_arrays , args_unpickleable )
565+         val  =  res_pik , res_unpickleable 
566+         prev  =  static_return_values .setdefault (key , val )
567+         assert  prev  ==  val , "cache key collision" 
568+         return  res_arrays 
569+ 
570+     jitted  =  jax .jit (inner , static_argnums = (0 , 2 ))
571+ 
572+     @wraps (func ) 
573+     def  outer (* args : P .args , ** kwargs : P .kwargs ) ->  T :  # numpydoc ignore=GL08 
574+         args_pik , args_arrays , args_unpickleable  =  pickle_without (
575+             (args , kwargs ),
576+             jax .Array ,  # pyright: ignore[reportUnknownArgumentType] 
577+         )
578+         res_arrays  =  jitted (args_pik , args_arrays , args_unpickleable )
579+         key  =  jit_cache_key (args_pik , args_arrays , args_unpickleable )
580+         res_pik , res_unpickleable  =  static_return_values [key ]
581+         return  unpickle_without (res_pik , res_arrays , res_unpickleable )  # pyright: ignore[reportUnknownArgumentType] 
582+ 
583+     return  outer 
0 commit comments