1717# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919
20+ import copyreg
2021import enum
2122import functools as ft
2223import importlib .util
@@ -317,6 +318,10 @@ def _check_shape(
317318 assert False
318319
319320
321+ def _pickle_array_annotation (x : type ["AbstractArray" ]):
322+ return x .dtype .__getitem__ , ((x .array_type , x .dim_str ),)
323+
324+
320325@ft .lru_cache (maxsize = None )
321326def _make_metaclass (base_metaclass ):
322327 class MetaAbstractArray (_MetaAbstractArray , base_metaclass ):
@@ -338,6 +343,8 @@ def __eq__(cls, other):
338343 def __hash__ (cls ):
339344 return id (cls )
340345
346+ copyreg .pickle (MetaAbstractArray , _pickle_array_annotation )
347+
341348 return MetaAbstractArray
342349
343350
@@ -358,11 +365,15 @@ class for `Float32[Array, "foo"]`.
358365 you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
359366 """
360367
368+ # This is what it was defined with.
369+ dtype : type ["AbstractDtype" ]
361370 array_type : Any
371+ dim_str : str
372+
373+ # This is the processed information we need for later typechecking.
362374 dtypes : list [str ]
363375 dims : tuple [_AbstractDimOrVariadicDim , ...]
364376 index_variadic : Optional [int ]
365- dim_str : str
366377
367378
368379_not_made = object ()
@@ -595,8 +606,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
595606 return (array_type , name , dtypes , dims , index_variadic , dim_str )
596607
597608
598- def _make_array (* args , ** kwargs ):
599- out = _make_array_cached (* args , ** kwargs )
609+ def _make_array (x , dim_str , dtype ):
610+ out = _make_array_cached (x , dim_str , dtype . dtypes , dtype . __name__ )
600611
601612 if type (out ) is tuple :
602613 array_type , name , dtypes , dims , index_variadic , dim_str = out
@@ -610,11 +621,12 @@ def _make_array(*args, **kwargs):
610621 name ,
611622 (AbstractArray ,) if array_type is Any else (array_type , AbstractArray ),
612623 dict (
624+ dtype = dtype ,
613625 array_type = array_type ,
626+ dim_str = dim_str ,
614627 dtypes = dtypes ,
615628 dims = dims ,
616629 index_variadic = index_variadic ,
617- dim_str = dim_str ,
618630 ),
619631 )
620632 if getattr (typing , "GENERATING_DOCUMENTATION" , False ):
@@ -654,10 +666,7 @@ def __getitem__(cls, item: tuple[Any, str]):
654666 array_type = bound
655667 del item
656668 if get_origin (array_type ) in _union_types :
657- out = [
658- _make_array (x , dim_str , cls .dtypes , cls .__name__ )
659- for x in get_args (array_type )
660- ]
669+ out = [_make_array (x , dim_str , cls ) for x in get_args (array_type )]
661670 out = tuple (x for x in out if x is not _not_made )
662671 if len (out ) == 0 :
663672 raise ValueError ("Invalid jaxtyping type annotation." )
@@ -666,7 +675,7 @@ def __getitem__(cls, item: tuple[Any, str]):
666675 else :
667676 out = Union [out ]
668677 else :
669- out = _make_array (array_type , dim_str , cls . dtypes , cls . __name__ )
678+ out = _make_array (array_type , dim_str , cls )
670679 if out is _not_made :
671680 raise ValueError ("Invalid jaxtyping type annotation." )
672681 return out
0 commit comments