Skip to content

Commit f987a10

Browse files
jaxtyping annotations are no longer subclasses of their array type.
This has been a subtle point that gets pretty tricky as plenty of classes aren't really designed to be subclassed further, and some may even take steps to ensure this isn't the case (e.g. concrete Equinox modules, or MLX arrays). I *think* the only use-case for this is dispatch using plum -- I imagine we can probably find another way to make that happen. Either way, if you're reading this because your code just broke, and you are relying on the current subclassing behaviour, then please open an issue on jaxtyping.
1 parent 6297f52 commit f987a10

File tree

3 files changed

+12
-41
lines changed

3 files changed

+12
-41
lines changed

jaxtyping/_array_types.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -318,34 +318,18 @@ def _check_shape(
318318
assert False
319319

320320

321+
def _return_abstractarray():
322+
return AbstractArray
323+
324+
321325
def _pickle_array_annotation(x: type["AbstractArray"]):
322-
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)
326+
if x is AbstractArray:
327+
return _return_abstractarray, ()
328+
else:
329+
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)
323330

324331

325-
@ft.lru_cache(maxsize=None)
326-
def _make_metaclass(base_metaclass):
327-
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
328-
# We have to use identity-based eq/hash behaviour. The reason for this is that
329-
# when deserializing using cloudpickle (very common, it seems), that cloudpickle
330-
# will actually attempt to put a partially constructed class in a dictionary.
331-
# So if we start accessing `cls.index_variadic` and the like here, then that
332-
# explodes.
333-
# See
334-
# https://github.com/patrick-kidger/jaxtyping/issues/198
335-
# https://github.com/patrick-kidger/jaxtyping/issues/261
336-
#
337-
# This does mean that if you want to compare two array annotations for equality
338-
# (e.g. this happens in jaxtyping's tests as part of checking correctness) then
339-
# a custom equality function must be used -- we can't put it here.
340-
def __eq__(cls, other):
341-
return cls is other
342-
343-
def __hash__(cls):
344-
return id(cls)
345-
346-
copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)
347-
348-
return MetaAbstractArray
332+
copyreg.pickle(_MetaAbstractArray, _pickle_array_annotation)
349333

350334

351335
def _check_scalar(dtype, dtypes, dims):
@@ -617,15 +601,10 @@ def _make_array(x, dim_str, dtype):
617601

618602
if type(out) is tuple:
619603
array_type, name, dtypes, dims, index_variadic, dim_str = out
620-
metaclass = (
621-
_make_metaclass(type)
622-
if array_type is Any
623-
else _make_metaclass(type(array_type))
624-
)
625604

626-
out = metaclass(
605+
out = _MetaAbstractArray(
627606
name,
628-
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
607+
(AbstractArray,),
629608
dict(
630609
dtype=dtype,
631610
array_type=array_type,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "jaxtyping"
3-
version = "0.2.38"
3+
version = "0.2.39"
44
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
55
readme = "README.md"
66
requires-python =">=3.10"

test/test_array.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -602,14 +602,6 @@ def test_arraylike(typecheck, getkey):
602602
)
603603

604604

605-
def test_subclass():
606-
assert issubclass(Float[Array, ""], Array)
607-
assert issubclass(Float[np.ndarray, ""], np.ndarray)
608-
609-
if torch is not None:
610-
assert issubclass(Float[torch.Tensor, ""], torch.Tensor)
611-
612-
613605
def test_ignored_names():
614606
x = Float[np.ndarray, "foo=4"]
615607

0 commit comments

Comments
 (0)