diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ab820e0..9fb930e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,7 +30,7 @@ Next verify the tests all pass: ```bash pip install -r test/requirements.txt pip install torch --extra-index-url https://download.pytorch.org/whl/cpu -pytest +pytest ./test ``` Then push your changes back to your fork of the repository: diff --git a/README.md b/README.md index 671f8a6..ae3e5e4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Type annotations **and runtime type-checking** for: -1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, and TensorFlow!)* +1. shape and dtype of [JAX](https://github.com/google/jax), [NumPy](https://github.com/numpy/numpy), [MLX](https://github.com/ml-explore/mlx), [PyTorch](https://github.com/pytorch/pytorch), [TensorFlow](https://github.com/tensorflow/tensorflow) arrays/tensors. 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). @@ -11,6 +11,7 @@ Type annotations **and runtime type-checking** for: from jaxtyping import Array, Float, PyTree # Accepts floating-point 2D arrays with matching axes +# Feel free to replace 'Array' with torch.Tensor, np.ndarray, tf.Tensor or mx.array. def matrix_multiply(x: Float[Array, "dim1 dim2"], y: Float[Array, "dim2 dim3"] ) -> Float[Array, "dim1 dim3"]: diff --git a/docs/api/array.md b/docs/api/array.md index 3e97982..8cbb2bf 100644 --- a/docs/api/array.md +++ b/docs/api/array.md @@ -90,8 +90,9 @@ jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one np.ndarray torch.Tensor tf.Tensor +mx.array ``` -That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow. +That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX. 
Some other types are also supported here: diff --git a/docs/index.md b/docs/index.md index 5c3f06d..0dfc9b1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,10 +2,9 @@ jaxtyping is a library providing type annotations **and runtime type-checking** for: -1. shape and dtype of [JAX](https://github.com/google/jax) arrays; +1. shape and dtype of [JAX](https://github.com/google/jax), [NumPy](https://github.com/numpy/numpy), [MLX](https://github.com/ml-explore/mlx), [PyTorch](https://github.com/pytorch/pytorch), [TensorFlow](https://github.com/tensorflow/tensorflow) arrays/tensors. 2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html). - *(Now also supports PyTorch, NumPy, and TensorFlow!)* ## Installation @@ -15,7 +14,7 @@ pip install jaxtyping Requires Python 3.10+. -JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc. +JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/MLX/etc. The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments). 
diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index 3c8252a..ac5e675 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -617,15 +617,10 @@ def _make_array(x, dim_str, dtype): if type(out) is tuple: array_type, name, dtypes, dims, index_variadic, dim_str = out - metaclass = ( - _make_metaclass(type) - if array_type is Any - else _make_metaclass(type(array_type)) - ) - + metaclass = _make_metaclass(type) out = metaclass( name, - (AbstractArray,) if array_type is Any else (array_type, AbstractArray), + (AbstractArray,), dict( dtype=dtype, array_type=array_type, diff --git a/test/requirements.txt b/test/requirements.txt index 44fc63b..c0b711b 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -8,3 +8,4 @@ pytest pytest-asyncio tensorflow typeguard<3 +mlx diff --git a/test/test_array.py b/test/test_array.py index 76f35d7..29fb991 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -32,6 +32,11 @@ except ImportError: torch = None +try: + import mlx.core as mx +except ImportError: + mx = None + from jaxtyping import ( AbstractArray, AbstractDtype, @@ -598,14 +603,6 @@ def test_arraylike(typecheck, getkey): ) -def test_subclass(): - assert issubclass(Float[Array, ""], Array) - assert issubclass(Float[np.ndarray, ""], np.ndarray) - - if torch is not None: - assert issubclass(Float[torch.Tensor, ""], torch.Tensor) - - def test_ignored_names(): x = Float[np.ndarray, "foo=4"] diff --git a/test/test_decorator.py b/test/test_decorator.py index 6fe32ab..8f08f91 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -12,6 +12,12 @@ from .helpers import assert_no_garbage, ParamError, ReturnError +try: + import mlx.core as mx +except ImportError: + mx = None + + class M(metaclass=abc.ABCMeta): @jaxtyped(typechecker=None) def f(self): ... 
@@ -128,6 +134,26 @@ def f(x: int, y: int = 1) -> Float[Array, "x {y}"]: f(1, 5) +@pytest.mark.skipif(mx is None, reason="MLX is not installed") +def test_mlx_decorator(jaxtyp, typecheck): + @jaxtyp(typecheck) + def hello(x: Float[mx.array, "8 16"]): + pass + + hello(mx.zeros((8, 16), dtype=mx.float32)) + + with pytest.raises(ParamError): + hello(mx.zeros((8, 14), dtype=mx.float32)) + + with pytest.raises(ParamError): + import numpy as np + + hello(np.zeros((8, 16), dtype=np.float32)) + + with pytest.raises(ParamError): + hello(mx.zeros((8, 16), dtype=mx.int32)) + + class _GlobalFoo: pass