Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import copyreg
import enum
import functools as ft
import importlib.util
Expand Down Expand Up @@ -317,6 +318,10 @@ def _check_shape(
assert False


def _pickle_array_annotation(x: type["AbstractArray"]):
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)


@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
Expand All @@ -338,6 +343,8 @@ def __eq__(cls, other):
def __hash__(cls):
return id(cls)

copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)

return MetaAbstractArray


Expand All @@ -358,11 +365,15 @@ class for `Float32[Array, "foo"]`.
you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
"""

# This is what it was defined with.
dtype: type["AbstractDtype"]
array_type: Any
dim_str: str

# This is the processed information we need for later typechecking.
dtypes: list[str]
dims: tuple[_AbstractDimOrVariadicDim, ...]
index_variadic: Optional[int]
dim_str: str


_not_made = object()
Expand Down Expand Up @@ -595,8 +606,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
return (array_type, name, dtypes, dims, index_variadic, dim_str)


def _make_array(*args, **kwargs):
out = _make_array_cached(*args, **kwargs)
def _make_array(x, dim_str, dtype):
out = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
Expand All @@ -610,11 +621,12 @@ def _make_array(*args, **kwargs):
name,
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
dict(
dtype=dtype,
array_type=array_type,
dim_str=dim_str,
dtypes=dtypes,
dims=dims,
index_variadic=index_variadic,
dim_str=dim_str,
),
)
if getattr(typing, "GENERATING_DOCUMENTATION", False):
Expand Down Expand Up @@ -654,10 +666,7 @@ def __getitem__(cls, item: tuple[Any, str]):
array_type = bound
del item
if get_origin(array_type) in _union_types:
out = [
_make_array(x, dim_str, cls.dtypes, cls.__name__)
for x in get_args(array_type)
]
out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]
out = tuple(x for x in out if x is not _not_made)
if len(out) == 0:
raise ValueError("Invalid jaxtyping type annotation.")
Expand All @@ -666,7 +675,7 @@ def __getitem__(cls, item: tuple[Any, str]):
else:
out = Union[out]
else:
out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
out = _make_array(array_type, dim_str, cls)
if out is _not_made:
raise ValueError("Invalid jaxtyping type annotation.")
return out
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxtyping"
version = "0.2.37"
version = "0.2.38"
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
readme = "README.md"
requires-python =">=3.10"
Expand Down
36 changes: 23 additions & 13 deletions test/test_serialisation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import cloudpickle
import numpy as np

Expand All @@ -7,19 +9,27 @@
except ImportError:
torch = None

from jaxtyping import AbstractArray, Array, Shaped
from jaxtyping import AbstractArray, Array, Float, Shaped


def test_pickle():
x = cloudpickle.dumps(Shaped[Array, ""])
cloudpickle.loads(x)

y = cloudpickle.dumps(AbstractArray)
cloudpickle.loads(y)

z = cloudpickle.dumps(Shaped[np.ndarray, ""])
cloudpickle.loads(z)

if torch is not None:
w = cloudpickle.dumps(Shaped[torch.Tensor, ""])
cloudpickle.loads(w)
for p in (pickle, cloudpickle):
x = p.dumps(Shaped[Array, ""])
y = p.loads(x)
assert y.dtype is Shaped
assert y.dim_str == ""

x = p.dumps(AbstractArray)
y = p.loads(x)
assert y is AbstractArray

x = p.dumps(Shaped[np.ndarray, "3 4 hi"])
y = p.loads(x)
assert y.dtype is Shaped
assert y.dim_str == "3 4 hi"

if torch is not None:
x = p.dumps(Float[torch.Tensor, "batch length"])
y = p.loads(x)
assert y.dtype is Float
assert y.dim_str == "batch length"