diff --git a/README.md b/README.md index 671f8a6..2b4676e 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,15 @@ Type annotations **and runtime type-checking** for: -1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, and TensorFlow!)* +1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)* 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). - **For example:** ```python from jaxtyping import Array, Float, PyTree # Accepts floating-point 2D arrays with matching axes +# You can replace `Array` with `torch.Tensor` etc. def matrix_multiply(x: Float[Array, "dim1 dim2"], y: Float[Array, "dim2 dim3"] ) -> Float[Array, "dim1 dim3"]: diff --git a/docs/api/advanced-features.md b/docs/api/advanced-features.md index 3c6dc5a..6c715c4 100644 --- a/docs/api/advanced-features.md +++ b/docs/api/advanced-features.md @@ -3,9 +3,8 @@ ## Creating your own dtypes ::: jaxtyping.AbstractDtype - selection: - members: - false + options: + members: [] ::: jaxtyping.make_numpy_struct_dtype diff --git a/docs/api/array.md b/docs/api/array.md index 4f170ab..b3d1eaf 100644 --- a/docs/api/array.md +++ b/docs/api/array.md @@ -90,8 +90,9 @@ jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one np.ndarray torch.Tensor tf.Tensor +mx.array ``` -That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow. +That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX. Some other types are also supported here: diff --git a/docs/api/pytree.md b/docs/api/pytree.md index fa864f0..d3a40ad 100644 --- a/docs/api/pytree.md +++ b/docs/api/pytree.md @@ -1,13 +1,14 @@ # PyTree annotations :::jaxtyping.PyTree - selection: - members: - false + options: + members: [] --- :::jaxtyping.PyTreeDef + options: + members: [] --- diff --git a/docs/index.md b/docs/index.md index 5c3f06d..57f6e74 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,11 +2,9 @@ jaxtyping is a library providing type annotations **and runtime type-checking** for: -1. shape and dtype of [JAX](https://github.com/google/jax) arrays; +1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)* 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). - *(Now also supports PyTorch, NumPy, and TensorFlow!)* - ## Installation ```bash @@ -25,6 +23,7 @@ The annotations provided by jaxtyping are compatible with runtime type-checking from jaxtyping import Array, Float, PyTree # Accepts floating-point 2D arrays with matching axes +# You can replace `Array` with `torch.Tensor` etc. def matrix_multiply(x: Float[Array, "dim1 dim2"], y: Float[Array, "dim2 dim3"] ) -> Float[Array, "dim1 dim3"]: diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 23b2733..0000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -mkdocs==1.3.0 # Main documentation generator. -mkdocs-material==7.3.6 # Theme -pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX. -mkdocstrings==0.17.0 # Autogenerate documentation from docstrings. -mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages. -pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects -mkdocs_include_exclude_files==0.0.1 # Tweak which files are included/excluded -jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. -pygments==2.14.0 -mkdocs-autorefs==1.0.1 -mkdocs-material-extensions==1.3.1 - -# Dependencies of jaxtyping itself. -# Always use most up-to-date versions. -jax[cpu] diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 7ccd02e..559d944 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -19,7 +19,6 @@ import functools as ft import importlib.metadata -import importlib.util import typing import warnings from typing import TypeAlias, Union @@ -145,102 +144,78 @@ UInt64 as UInt64, ) - # But crucially, does not actually import jax at all. We do that dynamically in - # __getattr__ if required. See #178. - if importlib.util.find_spec("jax") is not None: + if hasattr(typing, "GENERATING_DOCUMENTATION"): - @ft.cache - def __getattr__(item): - if item == "Array": - if getattr(typing, "GENERATING_DOCUMENTATION", False): + class Array: + pass - class Array: - pass + Array.__module__ = "builtins" + Array.__qualname__ = "Array" - Array.__module__ = "builtins" - Array.__qualname__ = "Array" - return Array - else: - import jax + class ArrayLike: + pass - return jax.Array - elif item == "ArrayLike": - if getattr(typing, "GENERATING_DOCUMENTATION", False): + ArrayLike.__module__ = "builtins" + ArrayLike.__qualname__ = "ArrayLike" - class ArrayLike: - pass + class PRNGKeyArray: + pass - ArrayLike.__module__ = "builtins" - ArrayLike.__qualname__ = "ArrayLike" - return ArrayLike - else: - import jax.typing + PRNGKeyArray.__module__ = "builtins" + PRNGKeyArray.__qualname__ = "PRNGKeyArray" - return jax.typing.ArrayLike - elif item == "PRNGKeyArray": - if getattr(typing, "GENERATING_DOCUMENTATION", False): + from ._pytree_type import PyTree as PyTree - class PRNGKeyArray: - pass + class PyTreeDef: + """Alias for `jax.tree_util.PyTreeDef`, which is the type of the + return from `jax.tree_util.tree_structure(...)`. + """ - PRNGKeyArray.__module__ = "builtins" - PRNGKeyArray.__qualname__ = "PRNGKeyArray" - return PRNGKeyArray - else: - # New-style `jax.random.key` have scalar shape and dtype `key`. - # Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype - # `uint32`. - import jax + if typing.GENERATING_DOCUMENTATION != "jaxtyping": + # Equinox etc. docs get just `PyTreeDef`. + # jaxtyping docs get `jaxtyping.PyTreeDef`. + PyTreeDef.__qualname__ = "PyTreeDef" + PyTreeDef.__module__ = "builtins" - return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]] - elif item == "DTypeLike": - import jax.typing + @ft.cache + def __getattr__(item): + if item == "Array": + import jax - return jax.typing.DTypeLike - elif item == "Scalar": - import jax + return jax.Array + elif item == "ArrayLike": + import jax.typing - return Shaped[jax.Array, ""] - elif item == "ScalarLike": - from . import ArrayLike + return jax.typing.ArrayLike + elif item == "PRNGKeyArray": + # New-style `jax.random.key` have scalar shape and dtype `key`. + # Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype + # `uint32`. + import jax - return Shaped[ArrayLike, ""] - elif item == "PyTree": - from ._pytree_type import PyTree + return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]] + elif item == "DTypeLike": + import jax.typing - return PyTree - elif item == "PyTreeDef": - if hasattr(typing, "GENERATING_DOCUMENTATION"): - # Most parts of the Equinox ecosystem have - # `typing.GENERATING_DOCUMENTATION = True` when generating - # documentation, to add whatever shims are necessary to get pretty - # docs. E.g. to have type annotations appear as just `PyTree`, not - # `jaxtyping.PyTree`. - # - # As jaxtyping actually wants things to appear as e.g. - # `jaxtyping.PyTree`, rather than just `PyTree`, then it sets - # `typing.GENERATING_DOCUMENTATION = False`, to disable these shims. - # - # Here we do only a `hasattr` check, as we want to get this version - # of `PyTreeDef` in both the jaxtyping and the Equinox(/etc.) docs. + return jax.typing.DTypeLike + elif item == "Scalar": + import jax - class PyTreeDef: - """Alias for `jax.tree_util.PyTreeDef`, which is the type of the - return from `jax.tree_util.tree_structure(...)`. - """ + return Shaped[jax.Array, ""] + elif item == "ScalarLike": + from . import ArrayLike - if typing.GENERATING_DOCUMENTATION: - # Equinox etc. docs get just `PyTreeDef`. - # jaxtyping docs get `jaxtyping.PyTreeDef`. - PyTreeDef.__qualname__ = "PyTreeDef" - PyTreeDef.__module__ = "builtins" - return PyTreeDef - else: - import jax.tree_util + return Shaped[ArrayLike, ""] + elif item == "PyTree": + from ._pytree_type import PyTree - return jax.tree_util.PyTreeDef - else: - raise AttributeError(f"module jaxtyping has no attribute {item!r}") + return PyTree + elif item == "PyTreeDef": + import jax.tree_util + + return jax.tree_util.PyTreeDef + else: + raise AttributeError(f"module jaxtyping has no attribute {item!r}") check_equinox_version = True # easy-to-replace line with copybara diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index 0e33be3..382af27 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -318,34 +318,18 @@ def _check_shape( assert False +def _return_abstractarray(): + return AbstractArray + + def _pickle_array_annotation(x: type["AbstractArray"]): - return x.dtype.__getitem__, ((x.array_type, x.dim_str),) + if x is AbstractArray: + return _return_abstractarray, () + else: + return x.dtype.__getitem__, ((x.array_type, x.dim_str),) -@ft.lru_cache(maxsize=None) -def _make_metaclass(base_metaclass): - class MetaAbstractArray(_MetaAbstractArray, base_metaclass): - # We have to use identity-based eq/hash behaviour. The reason for this is that - # when deserializing using cloudpickle (very common, it seems), that cloudpickle - # will actually attempt to put a partially constructed class in a dictionary. - # So if we start accessing `cls.index_variadic` and the like here, then that - # explodes. - # See - # https://github.com/patrick-kidger/jaxtyping/issues/198 - # https://github.com/patrick-kidger/jaxtyping/issues/261 - # - # This does mean that if you want to compare two array annotations for equality - # (e.g. this happens in jaxtyping's tests as part of checking correctness) then - # a custom equality function must be used -- we can't put it here. - def __eq__(cls, other): - return cls is other - - def __hash__(cls): - return id(cls) - - copyreg.pickle(MetaAbstractArray, _pickle_array_annotation) - - return MetaAbstractArray +copyreg.pickle(_MetaAbstractArray, _pickle_array_annotation) def _check_scalar(dtype, dtypes, dims): @@ -617,15 +601,10 @@ def _make_array(x, dim_str, dtype): if type(out) is tuple: array_type, name, dtypes, dims, index_variadic, dim_str = out - metaclass = ( - _make_metaclass(type) - if array_type is Any - else _make_metaclass(type(array_type)) - ) - out = metaclass( + out = _MetaAbstractArray( name, - (AbstractArray,) if array_type is Any else (array_type, AbstractArray), + (AbstractArray,), dict( dtype=dtype, array_type=array_type, @@ -635,10 +614,10 @@ def _make_array(x, dim_str, dtype): index_variadic=index_variadic, ), ) - if getattr(typing, "GENERATING_DOCUMENTATION", False): - out.__module__ = "builtins" - else: + if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}: out.__module__ = "jaxtyping" + else: + out.__module__ = "builtins" return out @@ -767,10 +746,10 @@ class _Cls(AbstractDtype): _Cls.__name__ = name _Cls.__qualname__ = name - if getattr(typing, "GENERATING_DOCUMENTATION", False): - _Cls.__module__ = "builtins" - else: + if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}: _Cls.__module__ = "jaxtyping" + else: + _Cls.__module__ = "builtins" return _Cls diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 5e0a517..85fc072 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -19,7 +19,6 @@ import dataclasses import functools as ft -import importlib.util import inspect import itertools as it import sys @@ -29,7 +28,6 @@ from typing import ( Any, get_args, - get_origin, get_type_hints, NoReturn, overload, @@ -90,7 +88,8 @@ def jaxtyped( def jaxtyped(fn=_sentinel, *, typechecker=_sentinel): """Decorate a function with this to perform runtime type-checking of its arguments - and return value. Decorate a dataclass to perform type-checking of its attributes. + and return value. Decorate a dataclass to perform type-checking of its `__init__` + method. !!! Example @@ -228,9 +227,10 @@ def f(...): ... if _tb_flag: try: import jax._src.traceback_util as traceback_util - traceback_util.register_exclusion(__file__) - except: + except Exception: pass + else: + traceback_util.register_exclusion(__file__) _tb_flag = False # First handle the `jaxtyped("context")` usage, which is a special case. @@ -288,34 +288,7 @@ def f(...): ... return ft.partial(jaxtyped, typechecker=typechecker) elif inspect.isclass(fn): if dataclasses.is_dataclass(fn) and typechecker is not None: - # This does not check that the arguments passed to `__init__` match the - # type annotations. There may be a custom user `__init__`, or a - # dataclass-generated `__init__` used alongside - # `equinox.field(converter=...)` - - init = fn.__init__ - - @ft.wraps(init) - def __init__(self, *args, **kwargs): - __tracebackhide__ = True - init(self, *args, **kwargs) - # `fn.__init__` is late-binding to the `__init__` function that - # we're in now. (Or to someone else's monkey-patch.) Either way, - # this checks that we're in the "top-level" `__init__`, and not one - # that is being called via `super()`. We don't want to trigger too - # early, before all fields have been assigned. - # - # We're not checking `if self.__class__ is fn` because Equinox - # replaces the with a defrozen version of itself during `__init__`, - # so the check wouldn't trigger. - # - # We're not doing this check by adding it to the end of the - # metaclass `__call__`, because Python doesn't allow you - # monkey-patch metaclasses. - if self.__class__.__init__ is fn.__init__: - _check_dataclass_annotations(self, typechecker) - - fn.__init__ = __init__ + fn.__init__ = jaxtyped(fn.__init__, typechecker=typechecker) return fn # It'd be lovely if we could handle arbitrary descriptors, and not just the builtin # ones. Unfortunately that means returning a class instance with a __get__ method, @@ -573,54 +546,6 @@ def __exit__(self, exc_type, exc_value, exc_tb): pop_shape_memo() -def _check_dataclass_annotations(self, typechecker): - """Creates and calls a function that checks the attributes of `self` - - `self` should be a dataclass instance. `typechecker` should be e.g. - `beartype.beartype` or `typeguard.typechecked`. - """ - parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] - values = {} - for field in dataclasses.fields(self): - annotation = field.type - if isinstance(annotation, str): - # Don't check stringified annotations. These are basically impossible to - # resolve correctly, so just skip them. - continue - if get_origin(annotation) is type: - args = get_args(annotation) - if len(args) == 1 and isinstance(args[0], str): - # We also special-case this one kind of partially-stringified type - # annotation, so as to support Equinox =0.1.3"] entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}} +[project.optional-dependencies] +docs = [ + "hippogriffe==0.1.0", + "mkdocs==1.6.1", + "mkdocs-include-exclude-files==0.1.0", + "mkdocs-ipynb==0.1.0", + "mkdocs-material==9.6.7", + "mkdocstrings[python]==0.28.3", + "pymdown-extensions==10.14.3", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/test/import_hook_tester.py b/test/import_hook_tester.py index abad4bb..122bf1b 100644 --- a/test/import_hook_tester.py +++ b/test/import_hook_tester.py @@ -69,14 +69,17 @@ class Mod2(eqx.Module): Mod2(1) # This will fail unless we run typechecking after conversion -class BadMod2(eqx.Module): - a: jnp.ndarray = eqx.field(converter=lambda x: x) +# This silently passes -- the untyped `lambda x: x` launders the value through. +# No easy way to tackle this. That's okay. +# class BadMod2(eqx.Module): +# a: jnp.ndarray = eqx.field(converter=lambda x: x) -with pytest.raises(ParamError): - BadMod2(1) -with pytest.raises(ParamError): - BadMod2("asdf") + +# with pytest.raises(ParamError): +# BadMod2(1) +# with pytest.raises(ParamError): +# BadMod2("asdf") # Custom `__init__`, no converter @@ -220,15 +223,13 @@ class Foo: class Bar(eqx.Module): x: type[Foo] y: "type[Foo]" - # Note that this is the *only* kind of partially-stringified type annotation that - # is supported. This is for compatibility with older Equinox versions. - z: type["Foo"] + # Partially-stringified hints not tested; not supported. -Bar(Foo, Foo, Foo) +Bar(Foo, Foo) with pytest.raises(ParamError): - Bar(1, Foo, Foo) + Bar(1, Foo) # diff --git a/test/requirements.txt b/test/requirements.txt index 44fc63b..c0b711b 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -8,3 +8,4 @@ pytest pytest-asyncio tensorflow typeguard<3 +mlx diff --git a/test/test_array.py b/test/test_array.py index 456b1e0..7fd036e 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -530,7 +530,7 @@ def test_deferred_symbolic_dataclass(typecheck): @dc.dataclass class A: value: int - array: Float[Array, " {self.value}"] + array: Float[Array, " {value}"] A(3, jnp.zeros(3)) @@ -602,14 +602,6 @@ def test_arraylike(typecheck, getkey): ) -def test_subclass(): - assert issubclass(Float[Array, ""], Array) - assert issubclass(Float[np.ndarray, ""], np.ndarray) - - if torch is not None: - assert issubclass(Float[torch.Tensor, ""], torch.Tensor) - - def test_ignored_names(): x = Float[np.ndarray, "foo=4"] diff --git a/test/test_decorator.py b/test/test_decorator.py index 6fe32ab..ea20467 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -3,6 +3,7 @@ import sys from typing import no_type_check +import equinox as eqx import jax.numpy as jnp import jax.random as jr import pytest @@ -254,3 +255,36 @@ class _Obj: x: int _Obj(x=5) + + +def test_equinox_converter(typecheck): + def _typed_str(x: int) -> str: + return str(x) + + @jaxtyped(typechecker=typecheck) + class X(eqx.Module): + x: str = eqx.field(converter=_typed_str) + + X(1) + with pytest.raises(ParamError): + X("1") + + +def test_mlx(jaxtyp, typecheck): + import mlx.core as mx + import numpy as np + + @jaxtyp(typecheck) + def hello(x: Float[mx.array, "8 16"]): + pass + + hello(mx.zeros((8, 16), dtype=mx.float32)) + + with pytest.raises(ParamError): + hello(mx.zeros((8, 14), dtype=mx.float32)) + + with pytest.raises(ParamError): + hello(np.zeros((8, 16), dtype=np.float32)) + + with pytest.raises(ParamError): + hello(mx.zeros((8, 16), dtype=mx.int32)) diff --git a/test/test_messages.py b/test/test_messages.py index 7987314..a00ccd2 100644 --- a/test/test_messages.py +++ b/test/test_messages.py @@ -56,7 +56,7 @@ def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]: "Type-check error whilst checking the return value of .*..f", r"Called with parameters: {'x': \(1, 2\), 'y': {'a': 1}}", "Actual value: 'foo'", - r"Expected type: PyTree\[Any, \"T S\"\].", + r"Expected type: PyTree\[Any, 'T S'\].", ( "The current values for each jaxtyping PyTree structure annotation are as " "follows." @@ -69,7 +69,7 @@ def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]: f(x, y=y) -def test_dataclass_attribute(typecheck): +def test_dataclass_init(typecheck): @jaxtyped(typechecker=typecheck) class M(eqx.Module): x: Float[Array, " *foo"] @@ -88,8 +88,8 @@ class M(eqx.Module): r"'y': \(1, \(3, 4\)\), 'z': 'not-an-int'}" ), ( - r"Parameter annotations: \(self: Any, x: Float\[Array, '\*foo'\], " - r"y: PyTree\[Any, \"T\"\], z: int\)." + r"Parameter annotations: \(self, x: Float\[Array, '\*foo'\], " + r"y: PyTree\[Any, 'T'\], z: int\)." ), "The current values for each jaxtyping axis annotation are as follows.", r"foo=\(2, 3\)", diff --git a/test/test_pytree.py b/test/test_pytree.py index 532769f..cefc0ba 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -17,6 +17,7 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +from collections.abc import Callable from typing import NamedTuple, Tuple, Union import equinox as eqx @@ -24,6 +25,7 @@ import jax.numpy as jnp import jax.random as jr import pytest +import wadler_lindig as wl import jaxtyping from jaxtyping import AnnotationError, Array, Float, PyTree @@ -341,3 +343,53 @@ def f(x: PyTree[PyTree[Float[Array, "?foo"], " S"], " T"]): x1 = jr.normal(getkey(), (2,)) with pytest.raises(AnnotationError, match="ambiguous which PyTree"): f(x1) + + +def test_name(): + assert PyTree.__name__ == "PyTree" + assert PyTree[int].__name__ == "PyTree[int]" + assert PyTree[int, "foo"].__name__ == "PyTree[int, 'foo']" + assert PyTree[PyTree[str], "foo"].__name__ == "PyTree[PyTree[str], 'foo']" + assert ( + PyTree[PyTree[str, "bar"], "foo"].__name__ + == "PyTree[PyTree[str, 'bar'], 'foo']" + ) + assert PyTree[PyTree[str, "bar"]].__name__ == "PyTree[PyTree[str, 'bar']]" + assert ( + PyTree[None | Callable[[PyTree[int, " T"]], str]].__name__ + == "PyTree[None | Callable[[PyTree[int, 'T']], str]]" + ) + + +def test_pdoc(): + assert wl.pformat(PyTree) == "PyTree" + assert wl.pformat(PyTree[int]) == "PyTree[int]" + assert wl.pformat(PyTree[int, "foo"]) == "PyTree[int, 'foo']" + assert wl.pformat(PyTree[PyTree[str], "foo"]) == "PyTree[PyTree[str], 'foo']" + assert ( + wl.pformat(PyTree[PyTree[str, "bar"], "foo"]) + == "PyTree[PyTree[str, 'bar'], 'foo']" + ) + assert wl.pformat(PyTree[PyTree[str, "bar"]]) == "PyTree[PyTree[str, 'bar']]" + assert ( + wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]]) + == "PyTree[None | Callable[[PyTree[int, 'T']], str]]" + ) + expected = """ +PyTree[ + None + | Callable[ + [ + PyTree[ + int, + 'T' + ] + ], + str + ] +] + """.strip() + assert ( + wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]], width=2).strip() + == expected + )