Skip to content

Commit 13eab3e

Browse files
Added docs and tests for MLX support
1 parent 2b45335 commit 13eab3e

File tree

5 files changed

+27
-6
lines changed

5 files changed

+27
-6
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
Type annotations **and runtime type-checking** for:
44

5-
1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, and TensorFlow!)*
5+
1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
66
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
77

8-
98
**For example:**
109
```python
1110
from jaxtyping import Array, Float, PyTree
1211

1312
# Accepts floating-point 2D arrays with matching axes
13+
# You can replace `Array` with `torch.Tensor` etc.
1414
def matrix_multiply(x: Float[Array, "dim1 dim2"],
1515
y: Float[Array, "dim2 dim3"]
1616
) -> Float[Array, "dim1 dim3"]:

docs/api/array.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one
9090
np.ndarray
9191
torch.Tensor
9292
tf.Tensor
93+
mx.array
9394
```
94-
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow.
95+
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.
9596

9697
Some other types are also supported here:
9798

docs/index.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
jaxtyping is a library providing type annotations **and runtime type-checking** for:
44

5-
1. shape and dtype of [JAX](https://github.com/google/jax) arrays;
5+
1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
66
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
77

8-
*(Now also supports PyTorch, NumPy, and TensorFlow!)*
9-
108
## Installation
119

1210
```bash
@@ -25,6 +23,7 @@ The annotations provided by jaxtyping are compatible with runtime type-checking
2523
from jaxtyping import Array, Float, PyTree
2624

2725
# Accepts floating-point 2D arrays with matching axes
26+
# You can replace `Array` with `torch.Tensor` etc.
2827
def matrix_multiply(x: Float[Array, "dim1 dim2"],
2928
y: Float[Array, "dim2 dim3"]
3029
) -> Float[Array, "dim1 dim3"]:

test/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pytest
88
pytest-asyncio
99
tensorflow
1010
typeguard<3
11+
mlx

test/test_decorator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,23 @@ class X(eqx.Module):
268268
X(1)
269269
with pytest.raises(ParamError):
270270
X("1")
271+
272+
273+
def test_mlx(jaxtyp, typecheck):
274+
import mlx.core as mx
275+
import numpy as np
276+
277+
@jaxtyp(typecheck)
278+
def hello(x: Float[mx.array, "8 16"]):
279+
pass
280+
281+
hello(mx.zeros((8, 16), dtype=mx.float32))
282+
283+
with pytest.raises(ParamError):
284+
hello(mx.zeros((8, 14), dtype=mx.float32))
285+
286+
with pytest.raises(ParamError):
287+
hello(np.zeros((8, 16), dtype=np.float32))
288+
289+
with pytest.raises(ParamError):
290+
hello(mx.zeros((8, 16), dtype=mx.int32))

0 commit comments

Comments
 (0)