diff --git a/README.md b/README.md
index 2b4676e..dc2829e 100644
--- a/README.md
+++ b/README.md
@@ -1,25 +1,17 @@
jaxtyping
-Type annotations **and runtime type-checking** for:
+A library providing type annotations **and runtime type-checking** for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.
-1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
-2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
+_The name 'jax'typing is now historical, we support all of the above and have no JAX dependency!_
-**For example:**
```python
-from jaxtyping import Array, Float, PyTree
+from jaxtyping import Float
+from torch import Tensor
# Accepts floating-point 2D arrays with matching axes
-# You can replace `Array` with `torch.Tensor` etc.
-def matrix_multiply(x: Float[Array, "dim1 dim2"],
- y: Float[Array, "dim2 dim3"]
- ) -> Float[Array, "dim1 dim3"]:
- ...
-
-def accepts_pytree_of_ints(x: PyTree[int]):
- ...
-
-def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
+def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
+ y: Float[Tensor, "dim2 dim3"]
+ ) -> Float[Tensor, "dim1 dim3"]:
...
```
@@ -31,8 +23,6 @@ pip install jaxtyping
Requires Python 3.10+.
-JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.
-
The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).
## Documentation
diff --git a/docs/api/advanced-features.md b/docs/api/advanced-features.md
index 6c715c4..e2d42f6 100644
--- a/docs/api/advanced-features.md
+++ b/docs/api/advanced-features.md
@@ -14,10 +14,16 @@
## Introspection
-If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
+::: jaxtyping.AbstractArray
+ options:
+ members: []
+
+!!! info
+
+ If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
-You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass.
+ You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass.
-You can check for arrays by doing `issubclass(x, AbstractArray)`. Here, `AbstractArray` is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for `Float32[Array, "foo"]`.
+ You can check for arrays by doing `issubclass(x, AbstractArray)`., For example, `issubclass(Float32[jax.Array, "some shape"], AbstractArray)` will pass.
-You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass.
+ You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass.
diff --git a/docs/api/array.md b/docs/api/array.md
index b3d1eaf..f82a98d 100644
--- a/docs/api/array.md
+++ b/docs/api/array.md
@@ -1,6 +1,26 @@
# Array annotations
-The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as `Float[Array, "batch channels"]`.
+The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as
+`jaxtyping.Float[torch.Tensor, "batch channels"]`.
+
+## Dtype
+
+The dtype should be any one of (all imported from `jaxtyping`):
+
+- Any dtype at all: `Shaped`
+ - Boolean: `Bool`
+ - Any integer, unsigned integer, floating, or complex: `Num`
+ - Any floating or complex: `Inexact`
+ - Any floating point: `Float`
+ - Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64`
+ - Any complex: `Complex`
+ - Of particular precision: `Complex64`, `Complex128`
+ - Any integer or unsigned integer: `Integer`
+ - Any unsigned integer: `UInt`
+ - Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
+ - Any signed integer: `Int`
+ - Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
+ - Any floating, integer, or unsigned integer: `Real`.
## Shape
@@ -11,7 +31,7 @@ The shape should be a string of space-separated symbols, such as `"a b c d"`. Ea
- `int`: fixed-size axis, e.g. `"28 28"`.
- `str`: variable-size axis, e.g. `"channels"`.
- A symbolic expression in terms of other variable-size axes, e.g.
- `def remove_last(x: Float[Array, "dim"]) -> Float[Array, "dim-1"]`.
+ `def remove_last(x: Float[torch.Tensor, "dim"]) -> Float[torch.Tensor, "dim-1"]`.
Symbolic expressions must not use any spaces, otherwise each piece is treated as as a separate axis.
When calling a function, variable-size axes and symbolic axes will be matched up across all arguments and checked for consistency. (See [Runtime type checking](./runtime-type-checking.md).)
@@ -21,10 +41,10 @@ When calling a function, variable-size axes and symbolic axes will be matched up
In addition some modifiers can be applied:
- Prepend `*` to an axis to indicate that it can match multiple axes, e.g. `"*batch"` will match zero or more batch axes.
-- Prepend `#` to an axis to indicate that it can be that size *or* equal to one -- i.e. broadcasting is acceptable, e.g.
- `def add(x: Float[Array, "#foo"], y: Float[Array, "#foo"]) -> Float[Array, "#foo"]`.
+- Prepend `#` to an axis to indicate that it can be that size *or* equal to one – i.e. broadcasting is acceptable, e.g.
+ `def add(x: Float[torch.Tensor, "#foo"], y: Float[torch.Tensor, "#foo"]) -> Float[torch.Tensor, "#foo"]`.
- Prepend `_` to an axis to disable any runtime checking of that axis (so that it can be used just as documentation). This can also be used as just `_` on its own: e.g. `"b c _ _"`.
-- Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[Array, "rows=4 cols=3"]`.
+- Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[torch.Tensor, "rows=4 cols=3"]`.
- Prepend `?` to an axis to indicate that its size can vary within a PyTree structure. (See [PyTree annotations](./pytree.md).)
When using multiple modifiers, their order does not matter.
@@ -35,72 +55,40 @@ As a special case:
**Notes**
-- To denote a scalar shape use `""`, e.g. `Float[Array, ""]`.
-- To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[Array, "..."]`.
+- To denote a scalar shape use `""`, e.g. `Float[torch.Tensor, ""]`.
+- To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[torch.Tensor, "..."]`.
- You cannot have more than one use of multiple-axes, i.e. you can only use `...` or `*name` at most once in each array.
- A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments.
- Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`, e.g.
```python
- def full(size: int, fill: float) -> Float[Array, "{size}"]:
+ def full(size: int, fill: float) -> Float[jax.Array, "{size}"]:
return jax.numpy.full((size,), fill)
class SomeClass:
some_value = 5
- def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
+ def full(self, fill: float) -> Float[jax.Array, "{self.some_value}+3"]:
return jax.numpy.full((self.some_value + 3,), fill)
```
-## Dtype
-
-The dtype should be any one of (all imported from `jaxtyping`):
+## Array
-- Any dtype at all: `Shaped`
- - Boolean: `Bool`
- - PRNG key: `Key`
- - Any integer, unsigned integer, floating, or complex: `Num`
- - Any floating or complex: `Inexact`
- - Any floating point: `Float`
- - Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64`
- - Any complex: `Complex`
- - Of particular precision: `Complex64`, `Complex128`
- - Any integer or unsigned intger: `Integer`
- - Any unsigned integer: `UInt`
- - Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
- - Any signed integer: `Int`
- - Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
- - Any floating, integer, or unsigned integer: `Real`.
+A variety of types are supported here:
-Unless you really want to force a particular precision, then for most applications you should probably allow any floating-point, any integer, etc. That is, use
-```python
-from jaxtyping import Array, Float
-Float[Array, "some_shape"]
-```
-rather than
-```python
-from jaxtyping import Array, Float32
-Float32[Array, "some_shape"]
-```
+**Arrays and Tensors:**
-## Array
+The following frameworks are supported:
-The array should typically be either one of:
```python
-jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another
+jax.Array / jax.numpy.ndarray # these are both aliases of one another
np.ndarray
torch.Tensor
tf.Tensor
mx.array
```
-That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.
-
-Some other types are also supported here:
+_Despite the now-historical name, 'jax'typing also supports NumPy + PyTorch + TensorFlow + MLX._
-**Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`.
-
-**Any:** use `typing.Any` to check just the shape/dtype, but not the array type.
-
-**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example,
+**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. The shape should be a `tuple[int, ...]` and the dtype should be a `str`. For example,
```python
class MyDuckArray:
@property
@@ -121,31 +109,34 @@ assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"])
# and that `x.dtype == "my_dtype"`
```
+**Any:** use `typing.Any` to check just the shape/dtype, but not the array type.
+
+**Unions:** these are unpacked. For example, `SomeDtype[A | B, "some shape"]` is equivalent to
+`SomeDtype[A, "some shape"] | SomeDtype[B, "some shape"]`.
+
**TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`.
-**Existing jaxtyped annotations:**
+**TypeAliasTypes:** Python 3.12 introduced the ability to write `type Foo = int | str`, in which case `Foo` is of type `typing.TypeAliasType`. In this case `SomeDtype[Foo, "some shape"]` corresponds to using the definition provided on the right hand side.
+
+**Existing jaxtyping annotations:**
```python
-Image = Float[Array, "channels height width"]
+Image = Float[jax.Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```
in which case the additional shape is prepended, and the acceptable dtypes are the intersection of the two dtype specifiers used. (So that e.g. `BatchImage = Shaped[Image, "batch"]` would work just as well. But `Bool[Image, "batch"]` would throw an error, as there are no dtypes that are both bools and floats.) Thus the above is equivalent to
```python
-BatchImage = Float[Array, "batch channels height width"]
+BatchImage = Float[jax.Array, "batch channels height width"]
```
-Note that `jaxtyping.{Array, ArrayLike}` are only available if JAX has been installed.
-
-## Scalars, PRNG keys
+## JAX-specific types
-For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
-```python
-Scalar = Shaped[Array, ""]
-ScalarLike = Shaped[ArrayLike, ""]
-
-# Left: new-style typed keys; right: old-style keys. See JEP 9263.
-PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
-```
+As `jaxtyping` originally got its start as a JAX-specific library, then we provide some JAX-specific types. These are all only available if JAX is installed.
-Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
+- `jaxtyping.Array`: alias for `jax.Array`
+- `jaxtyping.ArrayLike`: alias for `jax.typing.ArrayLike`
+- `jaxtyping.Scalar`: alias for `jaxtyping.Shaped[jax.Array, ""]`
+- `jaxtyping.ScalarLike`: alias for `jaxtyping.Shaped[jax.typing.ArrayLike, ""]`
+- `jaxtyping.Key`, which is the dtype of `jax.random.key`s. For example `jax.random.key(...)` produces a `jaxtyping.Key[jax.Array, ""]`.
+- `jaxtyping.PRNGKeyArray`: alias for `jaxtyping.Key[jax.Array, ""] | jaxtyping.UInt32[jax.Array, "2"]` (Left: new-style typed keys; right: old-style keys. See [JEP 9263](https://docs.jax.dev/en/latest/jep/9263-typed-keys.html).)
-Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.
+Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`.
diff --git a/docs/api/pytree.md b/docs/api/pytree.md
index d3a40ad..5e9b6ce 100644
--- a/docs/api/pytree.md
+++ b/docs/api/pytree.md
@@ -1,5 +1,7 @@
# PyTree annotations
+This is a JAX-specific feature, and is only available if JAX is installed.
+
:::jaxtyping.PyTree
options:
members: []
@@ -17,8 +19,8 @@
The prefix `?` may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:
```python
def f(
- x: PyTree[Shaped[Array, "?foo"], "T"],
- y: PyTree[Shaped[Array, "?foo"], "T"],
+ x: PyTree[Shaped[jax.Array, "?foo"], "T"],
+ y: PyTree[Shaped[jax.Array, "?foo"], "T"],
):
pass
```
@@ -42,6 +44,3 @@ f((x1, x1), (y0, y1)) # x1 does not have a size matching y0!
Internally, all that is happening is that `foo` is replaced with `0foo` for the first leaf, `1foo` for the next leaf, etc., so that each leaf gets a unique version of the name.
----
-
-Note that `jaxtyping.{PyTree, PyTreeDef}` are only available if JAX has been installed.
diff --git a/docs/api/runtime-type-checking.md b/docs/api/runtime-type-checking.md
index ec0cd86..c3f442d 100644
--- a/docs/api/runtime-type-checking.md
+++ b/docs/api/runtime-type-checking.md
@@ -2,11 +2,9 @@
(See the [FAQ](../faq.md) for details on static type checking.)
-Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance.
-
There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase.
-In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.)
+In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions – `3` and `4` – have some known issues.)
!!! warning
@@ -48,7 +46,11 @@ import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype # or any other runtime type checker
```
-Place this at the start of your notebook -- everything that is directly defined in the notebook, after this magic is run, will be hook'd.
+Place this at the start of your notebook – everything that is directly defined in the notebook, after this magic is run, will be hook'd.
+
+#### Interaction with `jax.jit`
+
+Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance.
#### Other runtime type-checking libraries
diff --git a/docs/faq.md b/docs/faq.md
index 97c448a..e63108c 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -18,7 +18,7 @@ In type annotations, strings are used for two different things. Sometimes they'r
Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do).
-In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before.
+In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[jax.Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[jax.Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[jax.Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before.
## Dataclass annotations aren't being checked properly.
diff --git a/docs/index.md b/docs/index.md
index 57f6e74..2b243be 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,9 +1,19 @@
# Getting started
-jaxtyping is a library providing type annotations **and runtime type-checking** for:
+A library providing type annotations **and runtime type-checking** for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.
-1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
-2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
+_The name 'jax'typing is now historical, we support all of the above and have no JAX dependency!_
+
+```python
+from jaxtyping import Float
+from torch import Tensor
+
+# Accepts floating-point 2D arrays with matching axes
+def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
+ y: Float[Tensor, "dim2 dim3"]
+ ) -> Float[Tensor, "dim1 dim3"]:
+ ...
+```
## Installation
@@ -13,29 +23,8 @@ pip install jaxtyping
Requires Python 3.10+.
-JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.
-
The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).
-## Example
-
-```python
-from jaxtyping import Array, Float, PyTree
-
-# Accepts floating-point 2D arrays with matching axes
-# You can replace `Array` with `torch.Tensor` etc.
-def matrix_multiply(x: Float[Array, "dim1 dim2"],
- y: Float[Array, "dim2 dim3"]
- ) -> Float[Array, "dim1 dim3"]:
- ...
-
-def accepts_pytree_of_ints(x: PyTree[int]):
- ...
-
-def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
- ...
-```
-
## Next steps
Have a read of the [Array annotations](./api/array.md) documentation on the left-hand bar!
diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py
index 3478cf4..a19b0e5 100644
--- a/jaxtyping/_array_types.py
+++ b/jaxtyping/_array_types.py
@@ -357,10 +357,7 @@ def _check_scalar(dtype, dtypes, dims):
class AbstractArray(metaclass=_MetaAbstractArray):
"""This is the base class of all shape-and-dtype-specified arrays, e.g. it's a base
- class for `Float32[Array, "foo"]`.
-
- This might be useful if you're trying to inspect type annotations yourself, e.g.
- you can check `issubclass(annotation, jaxtyping.AbstractArray)`.
+ class for `Float32[jax.Array, "foo"]`.
"""
# This is what it was defined with.
@@ -461,7 +458,7 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
treepath = True
elem = elem[1:]
# Allow e.g. `foo=4` as an alternate syntax for just `4`, so that one
- # can write e.g. `Float[Array, "rows=3 cols=4"]`
+ # can write e.g. `Float[jax.Array, "rows=3 cols=4"]`
elif elem.count("=") == 1:
_, elem = elem.split("=")
else:
@@ -579,8 +576,8 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
if len(dtypes) == 0:
raise ValueError(
"A jaxtyping annotation cannot be extended with no overlapping "
- "dtypes. For example, `Bool[Float[Array, 'dim1'], 'dim2']` is an "
- "error. You probably want to make the outer wrapper be `Shaped`."
+ "dtypes. For example, `Bool[Float[jax.Array, 'dim1'], 'dim2']` is "
+ "an error. You probably want to make the outer wrapper be `Shaped`."
)
if array_type.index_variadic is not None:
if index_variadic is None:
@@ -645,8 +642,8 @@ def __getitem__(cls, item: tuple[Any, str]):
if not isinstance(item, tuple) or len(item) != 2:
raise ValueError(
"As of jaxtyping v0.2.0, type annotations must now include both an "
- "array type and a shape. For example `Float[Array, 'foo bar']`.\n"
- "Ellipsis can be used to accept any shape: `Float[Array, '...']`."
+ "array type and a shape. For example `Float[jax.Array, 'foo bar']`.\n"
+ "Ellipsis can be used to accept any shape: `Float[jax.Array, '...']`."
)
array_type, dim_str = item
dim_str = dim_str.strip()
@@ -701,11 +698,11 @@ class AbstractDtype(metaclass=_MetaAbstractDtype):
class UInt8or16(AbstractDtype):
dtypes = ["uint8", "uint16"]
- UInt8or16[Array, "shape"]
+ UInt8or16[jax.Array, "shape"]
```
which is essentially equivalent to
```python
- Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]
+ UInt8[jax.Array, "shape"] | UInt16[jax.Array, "shape"]
```
"""
@@ -714,7 +711,7 @@ class UInt8or16(AbstractDtype):
def __init__(self, *args, **kwargs):
raise RuntimeError(
"AbstractDtype cannot be instantiated. Perhaps you wrote e.g. "
- '`Float32("shape")` when you mean `Float32[jnp.ndarray, "shape"]`?'
+ '`Float32("shape")` when you mean `Float32[jax.Array, "shape"]`?'
)
def __init_subclass__(cls, **kwargs):
diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py
index d50acd9..e32b020 100644
--- a/jaxtyping/_decorator.py
+++ b/jaxtyping/_decorator.py
@@ -100,18 +100,17 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
!!! Example
```python
- # Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
- from jaxtyping import Array, Float, jaxtyped
-
+ from torch import Tensor
+ from jaxtyping import Float, jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker
# Type-check a function
@jaxtyped(typechecker=typechecker)
- def batch_outer_product(x: Float[Array, "b c1"],
- y: Float[Array, "b c2"]
- ) -> Float[Array, "b c1 c2"]:
+ def batch_outer_product(x: Float[Tensor, "b c1"],
+ y: Float[Tensor, "b c2"]
+ ) -> Float[Tensor, "b c1 c2"]:
return x[:, :, None] * y[:, None, :]
# Type-check a dataclass
@@ -121,7 +120,7 @@ def batch_outer_product(x: Float[Array, "b c1"],
@dataclass
class MyDataclass:
x: int
- y: Float[Array, "b c"]
+ y: Float[Tensor, "b c"]
```
**Arguments:**
@@ -155,7 +154,7 @@ def f(x: int):
```python
@jaxtyped(typechecker=None)
def f(x):
- assert isinstance(x, Float[Array, "batch channel"])
+ assert isinstance(x, Float[Tensor, "batch channel"])
```
**Returns:**
@@ -166,7 +165,7 @@ def f(x):
If `fn` is a dataclass, then `fn` is returned directly, and additionally its
`__init__` method is wrapped and modified in-place.
- !!! Info "Old syntax"
+ ??? Info "Old syntax"
jaxtyping previously (before v0.2.24) recommended using this double-decorator
syntax:
@@ -184,7 +183,7 @@ def f(...): ...
**Dynamic contexts:**
- Put precisely, the axis names in e.g. `Float[Array, "batch channels"]` and the
+ Put precisely, the axis names in e.g. `Float[Tensor, "batch channels"]` and the
structure names in e.g. `PyTree[int, "T"]` are all scoped to the thread-local
dynamic context of a `jaxtyped`-wrapped function. If from within that function
we then call another `jaxtyped`-wrapped function, then a new context is pushed
@@ -196,7 +195,7 @@ def f(...): ...
**isinstance:**
Binding of a value against a name is done with an `isinstance` check, for
- example `isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"])` will bind
+ example `isinstance(jnp.zeros((3, 4)), Float[Tensor, "dim1 dim2"])` will bind
`dim1=3` and `dim2=4`. In practice these `isinstance` checks are usually done by
the run-time typechecker `typechecker` that is supplied as an argument.
@@ -208,7 +207,7 @@ def f(...): ...
Only `isinstance` checks that pass will contribute to the store of values; those
that fail will not. As such it is safe to write e.g.
- `assert not isinstance(x, Float32[Array, "foo"])`.
+ `assert not isinstance(x, Float32[Tensor, "foo"])`.
**Decoupling contexts from function calls:**
@@ -222,7 +221,7 @@ def f(...): ...
supports being used as a context manager, by passing it the string `"context"`:
```python
with jaxtyped("context"):
- assert isinstance(x, Float[Array, "batch channel"])
+ assert isinstance(x, Float[Tensor, "batch channel"])
```
This is equivalent to placing this code inside a new function wrapped in
`jaxtyped(typechecker=None)`. Usage like this is very rare; it's mostly only
diff --git a/jaxtyping/_import_hook.py b/jaxtyping/_import_hook.py
index b933fc0..736bd6f 100644
--- a/jaxtyping/_import_hook.py
+++ b/jaxtyping/_import_hook.py
@@ -347,9 +347,10 @@ def install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optiona
import main
### main.py
- from jaxtyping import Array, Float32
+ from torch import Tensor
+ from jaxtyping import Float32
- def f(x: Float32[Array, "batch channels"]):
+ def f(x: Float32[Tensor, "batch channels"]):
...
```
diff --git a/jaxtyping/_pytree_type.py b/jaxtyping/_pytree_type.py
index cacbb01..0ca77d8 100644
--- a/jaxtyping/_pytree_type.py
+++ b/jaxtyping/_pytree_type.py
@@ -293,7 +293,7 @@ def __pdoc__(self, **kwargs):
PyTree.__module__ = "jaxtyping"
else:
PyTree.__module__ = "builtins"
-PyTree.__doc__ = """Represents a PyTree.
+PyTree.__doc__ = """Represents a JAX PyTree.
Annotations of the following sorts are supported:
```python
@@ -312,7 +312,7 @@ def __pdoc__(self, **kwargs):
([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
- example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
+ example, `PyTree[int]` or `PyTree[str | Float32[jax.Array, "b c"]]`.
c. A structure name can also be passed. In this case
`jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
diff --git a/jaxtyping/_storage.py b/jaxtyping/_storage.py
index 6960883..651b265 100644
--- a/jaxtyping/_storage.py
+++ b/jaxtyping/_storage.py
@@ -116,12 +116,12 @@ def print_bindings():
```python
@jaxtyped(typechecker=...)
- def f(x: Float[Array, "foo bar"]):
+ def f(x: Float[jax.Array, "foo bar"]):
print_bindings()
...
```
- noting that these values are bounding during runtime typechecking, so that the
+ noting that these values are bound during runtime typechecking, so that the
[`jaxtyping.jaxtyped`][] decorator is required.
**Arguments:**
diff --git a/mkdocs.yml b/mkdocs.yml
index f2c7751..4bf7329 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -98,7 +98,7 @@ nav:
- 'index.md'
- API:
- 'api/array.md'
- - 'api/pytree.md'
- 'api/runtime-type-checking.md'
- 'api/advanced-features.md'
+ - 'api/pytree.md'
- 'faq.md'
diff --git a/pyproject.toml b/pyproject.toml
index 0007d47..4ef59a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,8 +27,8 @@ license = {file = "LICENSE"}
name = "jaxtyping"
readme = "README.md"
requires-python = ">=3.10"
-urls = {repository = "https://github.com/google/jaxtyping"}
-version = "0.3.4"
+urls = {repository = "https://github.com/patrick-kidger/jaxtyping"}
+version = "0.3.5"
[project.optional-dependencies]
dev = [