Closed
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -30,7 +30,7 @@ Next verify the tests all pass:
```bash
pip install -r test/requirements.txt
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pytest
pytest ./test
```

Then push your changes back to your fork of the repository:
3 changes: 2 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

Type annotations **and runtime type-checking** for:

1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, and TensorFlow!)*
1. shape and dtype of [JAX](https://github.com/google/jax), [NumPy](https://github.com/numpy/numpy), [MLX](https://github.com/ml-explore/mlx), [PyTorch](https://github.com/pytorch/pytorch), [TensorFlow](https://github.com/tensorflow/tensorflow) arrays/tensors;
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).


@@ -11,6 +11,7 @@ Type annotations **and runtime type-checking** for:
from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
# Feel free to replace 'Array' with torch.Tensor, np.ndarray, tf.Tensor, or mx.array.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
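The `matrix_multiply` example above relies on axis names (`dim1`, `dim2`, `dim3`) being bound consistently across arguments and the return value. A minimal plain-NumPy sketch of that binding logic (the helper `bind_axes` is hypothetical, not part of jaxtyping, which does this internally via its annotations and a runtime type-checker):

```python
import numpy as np


def bind_axes(bindings, x, spec):
    """Check x's shape against a space-separated axis spec, binding new names."""
    names = spec.split()
    if x.ndim != len(names):
        return False
    for name, size in zip(names, x.shape):
        # setdefault binds the name on first sight; later uses must match.
        if bindings.setdefault(name, size) != size:
            return False
    return True


# matrix_multiply's contract: "dim1 dim2" @ "dim2 dim3" -> "dim1 dim3"
b = {}
x = np.zeros((2, 3))
y = np.zeros((3, 4))
assert bind_axes(b, x, "dim1 dim2")
assert bind_axes(b, y, "dim2 dim3")
assert bind_axes(b, x @ y, "dim1 dim3")
assert not bind_axes(b, np.zeros((5, 4)), "dim2 dim3")  # dim2 is already bound to 3
```

The same spec-matching applies regardless of backend, which is why the README can list JAX, NumPy, MLX, PyTorch, and TensorFlow interchangeably.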
3 changes: 2 additions & 1 deletion docs/api/array.md
@@ -90,8 +90,9 @@ jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one
np.ndarray
torch.Tensor
tf.Tensor
mx.array
```
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow.
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.

Some other types are also supported here:

5 changes: 2 additions & 3 deletions docs/index.md
@@ -2,10 +2,9 @@

jaxtyping is a library providing type annotations **and runtime type-checking** for:

1. shape and dtype of [JAX](https://github.com/google/jax) arrays;
1. shape and dtype of [JAX](https://github.com/google/jax), [NumPy](https://github.com/numpy/numpy), [MLX](https://github.com/ml-explore/mlx), [PyTorch](https://github.com/pytorch/pytorch), [TensorFlow](https://github.com/tensorflow/tensorflow) arrays/tensors;
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).

*(Now also supports PyTorch, NumPy, and TensorFlow!)*

## Installation

@@ -15,7 +14,7 @@ pip install jaxtyping

Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.
JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/MLX/etc.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).

9 changes: 2 additions & 7 deletions jaxtyping/_array_types.py
@@ -617,15 +617,10 @@ def _make_array(x, dim_str, dtype):

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
metaclass = (
_make_metaclass(type)
if array_type is Any
else _make_metaclass(type(array_type))
)

metaclass = _make_metaclass(type)
out = metaclass(
Review comment (Owner): Can we skip creating the metaclass too? And have this line just be type(...
name,
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
(AbstractArray,),
dict(
dtype=dtype,
array_type=array_type,
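The diff above stops the generated annotation class from subclassing the concrete array type: its bases become `(AbstractArray,)` only, and the metaclass is always built from plain `type`. A self-contained sketch of the underlying pattern (simplified and hypothetical, not the library's actual code, which also checks dtype and axis specs):

```python
import numpy as np


class _ArrayAnnotationMeta(type):
    # isinstance() is routed through the metaclass, so the annotation class
    # can validate the array type and rank without inheriting from np.ndarray.
    def __instancecheck__(cls, obj):
        return isinstance(obj, cls.array_type) and obj.ndim == len(cls.dims)


# Mirrors the diff: the bases no longer include the concrete array type.
Float2D = _ArrayAnnotationMeta(
    "Float2D", (), dict(array_type=np.ndarray, dims=("dim1", "dim2"))
)

assert isinstance(np.zeros((2, 3)), Float2D)
assert not isinstance(np.zeros((2,)), Float2D)  # wrong rank
assert not isinstance([1.0, 2.0], Float2D)      # not an ndarray
assert not issubclass(Float2D, np.ndarray)      # no longer a real subclass
```

One consequence of dropping the concrete base class is that `issubclass` against the array type no longer holds, which is consistent with the removal of `test_subclass` from `test/test_array.py` later in this diff.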
1 change: 1 addition & 0 deletions test/requirements.txt
@@ -8,3 +8,4 @@ pytest
pytest-asyncio
tensorflow
typeguard<3
mlx
13 changes: 5 additions & 8 deletions test/test_array.py
@@ -32,6 +32,11 @@
except ImportError:
torch = None

try:
import mlx.core as mx
except ImportError:
mx = None

from jaxtyping import (
AbstractArray,
AbstractDtype,
@@ -598,14 +603,6 @@ def test_arraylike(typecheck, getkey):
)


def test_subclass():
assert issubclass(Float[Array, ""], Array)
assert issubclass(Float[np.ndarray, ""], np.ndarray)

if torch is not None:
assert issubclass(Float[torch.Tensor, ""], torch.Tensor)


def test_ignored_names():
x = Float[np.ndarray, "foo=4"]

26 changes: 26 additions & 0 deletions test/test_decorator.py
@@ -12,6 +12,12 @@
from .helpers import assert_no_garbage, ParamError, ReturnError


try:
import mlx.core as mx
except ImportError:
mx = None


class M(metaclass=abc.ABCMeta):
@jaxtyped(typechecker=None)
def f(self): ...
@@ -128,6 +134,26 @@ def f(x: int, y: int = 1) -> Float[Array, "x {y}"]:
f(1, 5)


@pytest.mark.skipif(mx is None, reason="MLX is not installed")
def test_mlx_decorator(jaxtyp, typecheck):
@jaxtyp(typecheck)
def hello(x: Float[mx.array, "8 16"]):
pass

hello(mx.zeros((8, 16), dtype=mx.float32))

with pytest.raises(ParamError):
hello(mx.zeros((8, 14), dtype=mx.float32))

with pytest.raises(ParamError):
import numpy as np

hello(np.zeros((8, 16), dtype=np.float32))

with pytest.raises(ParamError):
hello(mx.zeros((8, 16), dtype=mx.int32))


class _GlobalFoo:
pass

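The try/except import plus `pytest.mark.skipif` pattern used for MLX in the test files above generalizes to any optional backend; a minimal sketch with a hypothetical dependency `fakelib` standing in for `mlx.core`:

```python
import pytest

try:
    import fakelib  # hypothetical optional dependency, stands in for mlx.core
except ImportError:
    fakelib = None


@pytest.mark.skipif(fakelib is None, reason="fakelib is not installed")
def test_fakelib_feature():
    # Only runs when the optional dependency is importable.
    assert fakelib.zeros((8, 16)).shape == (8, 16)
```

Because the module-level name is set to `None` on import failure, the rest of the test suite imports cleanly and only the guarded tests are skipped.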