Skip to content

Commit 6813074

Browse files
Added support for pickling jaxtyping annotations.
1 parent bd84aed commit 6813074

File tree

3 files changed

+42
-23
lines changed

3 files changed

+42
-23
lines changed

jaxtyping/_array_types.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

20+
import copyreg
2021
import enum
2122
import functools as ft
2223
import importlib.util
@@ -317,6 +318,10 @@ def _check_shape(
317318
assert False
318319

319320

321+
def _pickle_array_annotation(x: type["AbstractArray"]):
322+
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)
323+
324+
320325
@ft.lru_cache(maxsize=None)
321326
def _make_metaclass(base_metaclass):
322327
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
@@ -338,6 +343,8 @@ def __eq__(cls, other):
338343
def __hash__(cls):
339344
return id(cls)
340345

346+
copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)
347+
341348
return MetaAbstractArray
342349

343350

@@ -358,11 +365,15 @@ class for `Float32[Array, "foo"]`.
358365
you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
359366
"""
360367

368+
# This is what it was defined with.
369+
dtype: type["AbstractDtype"]
361370
array_type: Any
371+
dim_str: str
372+
373+
# This is the processed information we need for later typechecking.
362374
dtypes: list[str]
363375
dims: tuple[_AbstractDimOrVariadicDim, ...]
364376
index_variadic: Optional[int]
365-
dim_str: str
366377

367378

368379
_not_made = object()
@@ -595,8 +606,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
595606
return (array_type, name, dtypes, dims, index_variadic, dim_str)
596607

597608

598-
def _make_array(*args, **kwargs):
599-
out = _make_array_cached(*args, **kwargs)
609+
def _make_array(x, dim_str, dtype):
610+
out = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)
600611

601612
if type(out) is tuple:
602613
array_type, name, dtypes, dims, index_variadic, dim_str = out
@@ -610,11 +621,12 @@ def _make_array(*args, **kwargs):
610621
name,
611622
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
612623
dict(
624+
dtype=dtype,
613625
array_type=array_type,
626+
dim_str=dim_str,
614627
dtypes=dtypes,
615628
dims=dims,
616629
index_variadic=index_variadic,
617-
dim_str=dim_str,
618630
),
619631
)
620632
if getattr(typing, "GENERATING_DOCUMENTATION", False):
@@ -654,10 +666,7 @@ def __getitem__(cls, item: tuple[Any, str]):
654666
array_type = bound
655667
del item
656668
if get_origin(array_type) in _union_types:
657-
out = [
658-
_make_array(x, dim_str, cls.dtypes, cls.__name__)
659-
for x in get_args(array_type)
660-
]
669+
out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]
661670
out = tuple(x for x in out if x is not _not_made)
662671
if len(out) == 0:
663672
raise ValueError("Invalid jaxtyping type annotation.")
@@ -666,7 +675,7 @@ def __getitem__(cls, item: tuple[Any, str]):
666675
else:
667676
out = Union[out]
668677
else:
669-
out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
678+
out = _make_array(array_type, dim_str, cls)
670679
if out is _not_made:
671680
raise ValueError("Invalid jaxtyping type annotation.")
672681
return out

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "jaxtyping"
3-
version = "0.2.37"
3+
version = "0.2.38"
44
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
55
readme = "README.md"
66
requires-python =">=3.10"

test/test_serialisation.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pickle
2+
13
import cloudpickle
24
import numpy as np
35

@@ -7,19 +9,27 @@
79
except ImportError:
810
torch = None
911

10-
from jaxtyping import AbstractArray, Array, Shaped
12+
from jaxtyping import AbstractArray, Array, Float, Shaped
1113

1214

1315
def test_pickle():
14-
x = cloudpickle.dumps(Shaped[Array, ""])
15-
cloudpickle.loads(x)
16-
17-
y = cloudpickle.dumps(AbstractArray)
18-
cloudpickle.loads(y)
19-
20-
z = cloudpickle.dumps(Shaped[np.ndarray, ""])
21-
cloudpickle.loads(z)
22-
23-
if torch is not None:
24-
w = cloudpickle.dumps(Shaped[torch.Tensor, ""])
25-
cloudpickle.loads(w)
16+
for p in (pickle, cloudpickle):
17+
x = p.dumps(Shaped[Array, ""])
18+
y = p.loads(x)
19+
assert y.dtype is Shaped
20+
assert y.dim_str == ""
21+
22+
x = p.dumps(AbstractArray)
23+
y = p.loads(x)
24+
assert y is AbstractArray
25+
26+
x = p.dumps(Shaped[np.ndarray, "3 4 hi"])
27+
y = p.loads(x)
28+
assert y.dtype is Shaped
29+
assert y.dim_str == "3 4 hi"
30+
31+
if torch is not None:
32+
x = p.dumps(Float[torch.Tensor, "batch length"])
33+
y = p.loads(x)
34+
assert y.dtype is Float
35+
assert y.dim_str == "batch length"

0 commit comments

Comments
 (0)