Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

Type annotations **and runtime type-checking** for:

1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, and TensorFlow!)*
1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).


**For example:**
```python
from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
# You can replace `Array` with `torch.Tensor` etc.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
Expand Down
5 changes: 2 additions & 3 deletions docs/api/advanced-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
## Creating your own dtypes

::: jaxtyping.AbstractDtype
selection:
members:
false
options:
members: []

::: jaxtyping.make_numpy_struct_dtype

Expand Down
3 changes: 2 additions & 1 deletion docs/api/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one
np.ndarray
torch.Tensor
tf.Tensor
mx.array
```
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow.
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.

Some other types are also supported here:

Expand Down
7 changes: 4 additions & 3 deletions docs/api/pytree.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# PyTree annotations

:::jaxtyping.PyTree
selection:
members:
false
options:
members: []

---

:::jaxtyping.PyTreeDef
options:
members: []

---

Expand Down
5 changes: 2 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

jaxtyping is a library providing type annotations **and runtime type-checking** for:

1. shape and dtype of [JAX](https://github.com/google/jax) arrays;
1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).

*(Now also supports PyTorch, NumPy, and TensorFlow!)*

## Installation

```bash
Expand All @@ -25,6 +23,7 @@ The annotations provided by jaxtyping are compatible with runtime type-checking
from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
# You can replace `Array` with `torch.Tensor` etc.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
Expand Down
15 changes: 0 additions & 15 deletions docs/requirements.txt

This file was deleted.

135 changes: 55 additions & 80 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import functools as ft
import importlib.metadata
import importlib.util
import typing
import warnings
from typing import TypeAlias, Union
Expand Down Expand Up @@ -145,102 +144,78 @@
UInt64 as UInt64,
)

# But crucially, does not actually import jax at all. We do that dynamically in
# __getattr__ if required. See #178.
if importlib.util.find_spec("jax") is not None:
if hasattr(typing, "GENERATING_DOCUMENTATION"):

@ft.cache
def __getattr__(item):
if item == "Array":
if getattr(typing, "GENERATING_DOCUMENTATION", False):
class Array:
pass

class Array:
pass
Array.__module__ = "builtins"
Array.__qualname__ = "Array"

Array.__module__ = "builtins"
Array.__qualname__ = "Array"
return Array
else:
import jax
class ArrayLike:
pass

return jax.Array
elif item == "ArrayLike":
if getattr(typing, "GENERATING_DOCUMENTATION", False):
ArrayLike.__module__ = "builtins"
ArrayLike.__qualname__ = "ArrayLike"

class ArrayLike:
pass
class PRNGKeyArray:
pass

ArrayLike.__module__ = "builtins"
ArrayLike.__qualname__ = "ArrayLike"
return ArrayLike
else:
import jax.typing
PRNGKeyArray.__module__ = "builtins"
PRNGKeyArray.__qualname__ = "PRNGKeyArray"

return jax.typing.ArrayLike
elif item == "PRNGKeyArray":
if getattr(typing, "GENERATING_DOCUMENTATION", False):
from ._pytree_type import PyTree as PyTree

class PRNGKeyArray:
pass
class PyTreeDef:
"""Alias for `jax.tree_util.PyTreeDef`, which is the type of the
return from `jax.tree_util.tree_structure(...)`.
"""

PRNGKeyArray.__module__ = "builtins"
PRNGKeyArray.__qualname__ = "PRNGKeyArray"
return PRNGKeyArray
else:
# New-style `jax.random.key` have scalar shape and dtype `key<foo>`.
# Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype
# `uint32`.
import jax
if typing.GENERATING_DOCUMENTATION != "jaxtyping":
# Equinox etc. docs get just `PyTreeDef`.
# jaxtyping docs get `jaxtyping.PyTreeDef`.
PyTreeDef.__qualname__ = "PyTreeDef"
PyTreeDef.__module__ = "builtins"

return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]]
elif item == "DTypeLike":
import jax.typing
@ft.cache
def __getattr__(item):
if item == "Array":
import jax

return jax.typing.DTypeLike
elif item == "Scalar":
import jax
return jax.Array
elif item == "ArrayLike":
import jax.typing

return Shaped[jax.Array, ""]
elif item == "ScalarLike":
from . import ArrayLike
return jax.typing.ArrayLike
elif item == "PRNGKeyArray":
# New-style `jax.random.key` have scalar shape and dtype `key<foo>`.
# Old-style `jax.random.PRNGKey` have shape `(2,)` and dtype
# `uint32`.
import jax

return Shaped[ArrayLike, ""]
elif item == "PyTree":
from ._pytree_type import PyTree
return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]]
elif item == "DTypeLike":
import jax.typing

return PyTree
elif item == "PyTreeDef":
if hasattr(typing, "GENERATING_DOCUMENTATION"):
# Most parts of the Equinox ecosystem have
# `typing.GENERATING_DOCUMENTATION = True` when generating
# documentation, to add whatever shims are necessary to get pretty
# docs. E.g. to have type annotations appear as just `PyTree`, not
# `jaxtyping.PyTree`.
#
# As jaxtyping actually wants things to appear as e.g.
# `jaxtyping.PyTree`, rather than just `PyTree`, then it sets
# `typing.GENERATING_DOCUMENTATION = False`, to disable these shims.
#
# Here we do only a `hasattr` check, as we want to get this version
# of `PyTreeDef` in both the jaxtyping and the Equinox(/etc.) docs.
return jax.typing.DTypeLike
elif item == "Scalar":
import jax

class PyTreeDef:
"""Alias for `jax.tree_util.PyTreeDef`, which is the type of the
return from `jax.tree_util.tree_structure(...)`.
"""
return Shaped[jax.Array, ""]
elif item == "ScalarLike":
from . import ArrayLike

if typing.GENERATING_DOCUMENTATION:
# Equinox etc. docs get just `PyTreeDef`.
# jaxtyping docs get `jaxtyping.PyTreeDef`.
PyTreeDef.__qualname__ = "PyTreeDef"
PyTreeDef.__module__ = "builtins"
return PyTreeDef
else:
import jax.tree_util
return Shaped[ArrayLike, ""]
elif item == "PyTree":
from ._pytree_type import PyTree

return jax.tree_util.PyTreeDef
else:
raise AttributeError(f"module jaxtyping has no attribute {item!r}")
return PyTree
elif item == "PyTreeDef":
import jax.tree_util

return jax.tree_util.PyTreeDef
else:
raise AttributeError(f"module jaxtyping has no attribute {item!r}")


check_equinox_version = True # easy-to-replace line with copybara
Expand Down
55 changes: 17 additions & 38 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,34 +318,18 @@ def _check_shape(
assert False


def _return_abstractarray():
return AbstractArray


def _pickle_array_annotation(x: type["AbstractArray"]):
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)
if x is AbstractArray:
return _return_abstractarray, ()
else:
return x.dtype.__getitem__, ((x.array_type, x.dim_str),)


@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
# We have to use identity-based eq/hash behaviour. The reason for this is that
# when deserializing using cloudpickle (very common, it seems), that cloudpickle
# will actually attempt to put a partially constructed class in a dictionary.
# So if we start accessing `cls.index_variadic` and the like here, then that
# explodes.
# See
# https://github.com/patrick-kidger/jaxtyping/issues/198
# https://github.com/patrick-kidger/jaxtyping/issues/261
#
# This does mean that if you want to compare two array annotations for equality
# (e.g. this happens in jaxtyping's tests as part of checking correctness) then
# a custom equality function must be used -- we can't put it here.
def __eq__(cls, other):
return cls is other

def __hash__(cls):
return id(cls)

copyreg.pickle(MetaAbstractArray, _pickle_array_annotation)

return MetaAbstractArray
copyreg.pickle(_MetaAbstractArray, _pickle_array_annotation)


def _check_scalar(dtype, dtypes, dims):
Expand Down Expand Up @@ -617,15 +601,10 @@ def _make_array(x, dim_str, dtype):

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
metaclass = (
_make_metaclass(type)
if array_type is Any
else _make_metaclass(type(array_type))
)

out = metaclass(
out = _MetaAbstractArray(
name,
(AbstractArray,) if array_type is Any else (array_type, AbstractArray),
(AbstractArray,),
dict(
dtype=dtype,
array_type=array_type,
Expand All @@ -635,10 +614,10 @@ def _make_array(x, dim_str, dtype):
index_variadic=index_variadic,
),
)
if getattr(typing, "GENERATING_DOCUMENTATION", False):
out.__module__ = "builtins"
else:
if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
out.__module__ = "jaxtyping"
else:
out.__module__ = "builtins"

return out

Expand Down Expand Up @@ -767,10 +746,10 @@ class _Cls(AbstractDtype):

_Cls.__name__ = name
_Cls.__qualname__ = name
if getattr(typing, "GENERATING_DOCUMENTATION", False):
_Cls.__module__ = "builtins"
else:
if getattr(typing, "GENERATING_DOCUMENTATION", "") in {"", "jaxtyping"}:
_Cls.__module__ = "jaxtyping"
else:
_Cls.__module__ = "builtins"
return _Cls


Expand Down
Loading