55import io
66import math
77import pickle
8- from collections .abc import Generator , Iterable
8+ from collections .abc import Callable , Generator , Hashable , Iterable
9+ from functools import wraps
910from types import ModuleType , NoneType
10- from typing import TYPE_CHECKING , Any , Literal , TypeVar , cast
11+ from typing import TYPE_CHECKING , Any , Literal , ParamSpec , TypeVar , cast
1112
1213from . import _compat
1314from ._compat import (
@@ -29,6 +30,7 @@ def override(func):
2930 return func
3031
3132
33+ P = ParamSpec ("P" )
3234T = TypeVar ("T" )
3335
3436
@@ -38,6 +40,7 @@ def override(func):
3840 "eager_shape" ,
3941 "in1d" ,
4042 "is_python_scalar" ,
43+ "jax_autojit" ,
4144 "mean" ,
4245 "meta_namespace" ,
4346 "pickle_without" ,
@@ -368,7 +371,7 @@ def pickle_without(
368371 >>> class A:
369372 ... def __repr__(self):
370373 ... return "<A>"
371- >>> obj = {1: A(), 2: [A(), NS(), A()]} # Any serializable object
374+ >>> obj = {1: A(), 2: [A(), NS(), A()]}
372375 >>> pik, instances, unpickleable = pickle_without(obj, A)
373376 >>> instances, unpickleable
374377 ([<A>, <A>, <A>], [<NS>])
@@ -396,7 +399,6 @@ class Pickler(pickle.Pickler): # numpydoc ignore=GL01,RT01
396399
397400 @override
398401 def persistent_id (self , obj : object ) -> object : # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
399-
400402 # Fast exit in case of basic builtin types.
401403 # Note that basic collections (tuple, list, etc.) are in this;
402404 # persistent_id() will be called again with their contents.
@@ -416,7 +418,9 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
416418 return id_ , 0
417419
418420 try :
419- _ = obj .__reduce__ ()
421+ # a class that defines __slots__ without defining __getstate__
422+ # cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
423+ _ = obj .__reduce_ex__ (pickle .HIGHEST_PROTOCOL )
420424 except Exception : # pylint: disable=broad-exception-caught
421425 pass
422426 else : # Can be pickled
@@ -425,7 +429,7 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
425429
426430 # May be a global function, which can be pickled
427431 try :
428- _ = pickle .dumps (obj )
432+ _ = pickle .dumps (obj , protocol = pickle . HIGHEST_PROTOCOL )
429433 except Exception : # pylint: disable=broad-exception-caught
430434 pass
431435 else : # Can be pickled
@@ -438,7 +442,7 @@ def persistent_id(self, obj: object) -> object: # pyright: ignore[reportIncompa
438442 return id_ , 1
439443
440444 f = io .BytesIO ()
441- p = Pickler (f )
445+ p = Pickler (f , protocol = pickle . HIGHEST_PROTOCOL )
442446 p .dump (obj )
443447 return f .getvalue (), tuple (instances ), tuple (unpickleable )
444448
@@ -480,7 +484,7 @@ def unpickle_without( # type: ignore[explicit-any]
480484 quietly ignored.
481485 """
482486 iters = iter (instances ), iter (unpickleable )
483- seen : dict [int , object ] = {}
487+ seen : dict [tuple [ int , int ] , object ] = {}
484488
485489 class Unpickler (pickle .Unpickler ): # numpydoc ignore=GL01,RT01
486490 """
@@ -509,3 +513,72 @@ def persistent_load(self, pid: tuple[int, int]) -> object: # pyright: ignore[re
509513
510514 f = io .BytesIO (pik )
511515 return Unpickler (f ).load ()
516+
517+
518+ def jax_autojit (
519+ func : Callable [P , T ],
520+ ) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01,SS03
521+ """
522+ Wrap `func` with ``jax.jit``, with the following differences:
523+
524+ - Array-like arguments and return values are not automatically converted to
525+ ``jax.Array`` objects.
526+ - All non-array arguments are automatically treated as static.
527+ Unlike ``jax.jit``, static arguments must be either hashable or serializable with
528+ ``pickle``.
529+ - Unlike ``jax.jit``, non-array arguments and return values are not limited to
530+ tuple/list/dict, but can be any object serializable with ``pickle``.
531+ - Automatically descend into non-array arguments and find ``jax.Array`` objects
532+ inside them, then rebuild the arguments when entering `func`, swapping the JAX
533+ concrete arrays with tracer objects.
534+ - Automatically descend into non-array return values and find ``jax.Array`` objects
535+ inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
536+ tracer objects with concrete arrays.
537+ """
538+ import jax
539+
540+ # {
541+ # jit_cache_key(args_pik, args_arrays, args_unpickleable):
542+ # (res_pik, res_unpickleable)
543+ # }
544+ static_return_values : dict [Hashable , tuple [bytes , tuple [object , ...]]] = {}
545+
546+ def jit_cache_key ( # type: ignore[no-any-unimported] # numpydoc ignore=GL08
547+ args_pik : bytes ,
548+ args_arrays : tuple [jax .Array , ...], # pyright: ignore[reportUnknownParameterType]
549+ args_unpickleable : tuple [Hashable , ...],
550+ ) -> Hashable :
551+ return (
552+ args_pik ,
553+ tuple ((arr .shape , arr .dtype ) for arr in args_arrays ), # pyright: ignore[reportUnknownArgumentType]
554+ args_unpickleable ,
555+ )
556+
557+ def inner ( # type: ignore[no-any-unimported] # pyright: ignore[reportUnknownParameterType]
558+ args_pik : bytes ,
559+ args_arrays : tuple [jax .Array , ...], # pyright: ignore[reportUnknownParameterType]
560+ args_unpickleable : tuple [Hashable , ...],
561+ ) -> tuple [jax .Array , ...]: # numpydoc ignore=GL08
562+ args , kwargs = unpickle_without (args_pik , args_arrays , args_unpickleable ) # pyright: ignore[reportUnknownArgumentType]
563+ res = func (* args , ** kwargs ) # pyright: ignore[reportCallIssue]
564+ res_pik , res_arrays , res_unpickleable = pickle_without (res , jax .Array ) # pyright: ignore[reportUnknownArgumentType]
565+ key = jit_cache_key (args_pik , args_arrays , args_unpickleable )
566+ val = res_pik , res_unpickleable
567+ prev = static_return_values .setdefault (key , val )
568+ assert prev == val , "cache key collision"
569+ return res_arrays
570+
571+ jitted = jax .jit (inner , static_argnums = (0 , 2 ))
572+
573+ @wraps (func )
574+ def outer (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
575+ args_pik , args_arrays , args_unpickleable = pickle_without (
576+ (args , kwargs ),
577+ jax .Array , # pyright: ignore[reportUnknownArgumentType]
578+ )
579+ res_arrays = jitted (args_pik , args_arrays , args_unpickleable )
580+ key = jit_cache_key (args_pik , args_arrays , args_unpickleable )
581+ res_pik , res_unpickleable = static_return_values [key ]
582+ return unpickle_without (res_pik , res_arrays , res_unpickleable ) # pyright: ignore[reportUnknownArgumentType]
583+
584+ return outer
0 commit comments