Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
<h1 align="center">jaxtyping</h1>

Type annotations **and runtime type-checking** for:
A library providing type annotations **and runtime type-checking** for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.

1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
_The name 'jax'typing is now historical, we support all of the above and have no JAX dependency!_

**For example:**
```python
from jaxtyping import Array, Float, PyTree
from jaxtyping import Float
from torch import Tensor

# Accepts floating-point 2D arrays with matching axes
# You can replace `Array` with `torch.Tensor` etc.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
...

def accepts_pytree_of_ints(x: PyTree[int]):
...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
y: Float[Tensor, "dim2 dim3"]
) -> Float[Tensor, "dim1 dim3"]:
...
```

Expand All @@ -31,8 +23,6 @@ pip install jaxtyping

Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).

## Documentation
Expand Down
14 changes: 10 additions & 4 deletions docs/api/advanced-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@

## Introspection

If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
::: jaxtyping.AbstractArray
options:
members: []

!!! info

If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.

You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass.
You can check for dtypes by doing `issubclass(x, AbstractDtype)`. For example, `issubclass(Float32, AbstractDtype)` will pass.

You can check for arrays by doing `issubclass(x, AbstractArray)`. Here, `AbstractArray` is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for `Float32[Array, "foo"]`.
You can check for arrays by doing `issubclass(x, AbstractArray)`., For example, `issubclass(Float32[jax.Array, "some shape"], AbstractArray)` will pass.

You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass.
You can check for pytrees by doing `issubclass(x, PyTree)`. For example, `issubclass(PyTree[int], PyTree)` will pass.
119 changes: 55 additions & 64 deletions docs/api/array.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
# Array annotations

The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as `Float[Array, "batch channels"]`.
The shape and dtypes of arrays can be annotated in the form `dtype[array, shape]`, such as
`jaxtyping.Float[torch.Tensor, "batch channels"]`.

## Dtype

The dtype should be any one of (all imported from `jaxtyping`):

- Any dtype at all: `Shaped`
- Boolean: `Bool`
- Any integer, unsigned integer, floating, or complex: `Num`
- Any floating or complex: `Inexact`
- Any floating point: `Float`
- Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64`
- Any complex: `Complex`
- Of particular precision: `Complex64`, `Complex128`
- Any integer or unsigned integer: `Integer`
- Any unsigned integer: `UInt`
- Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
- Any signed integer: `Int`
- Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
- Any floating, integer, or unsigned integer: `Real`.

## Shape

Expand All @@ -11,7 +31,7 @@ The shape should be a string of space-separated symbols, such as `"a b c d"`. Ea
- `int`: fixed-size axis, e.g. `"28 28"`.
- `str`: variable-size axis, e.g. `"channels"`.
- A symbolic expression in terms of other variable-size axes, e.g.
`def remove_last(x: Float[Array, "dim"]) -> Float[Array, "dim-1"]`.
`def remove_last(x: Float[torch.Tensor, "dim"]) -> Float[torch.Tensor, "dim-1"]`.
Symbolic expressions must not use any spaces, otherwise each piece is treated as as a separate axis.

When calling a function, variable-size axes and symbolic axes will be matched up across all arguments and checked for consistency. (See [Runtime type checking](./runtime-type-checking.md).)
Expand All @@ -21,10 +41,10 @@ When calling a function, variable-size axes and symbolic axes will be matched up
In addition some modifiers can be applied:

- Prepend `*` to an axis to indicate that it can match multiple axes, e.g. `"*batch"` will match zero or more batch axes.
- Prepend `#` to an axis to indicate that it can be that size *or* equal to one -- i.e. broadcasting is acceptable, e.g.
`def add(x: Float[Array, "#foo"], y: Float[Array, "#foo"]) -> Float[Array, "#foo"]`.
- Prepend `#` to an axis to indicate that it can be that size *or* equal to one i.e. broadcasting is acceptable, e.g.
`def add(x: Float[torch.Tensor, "#foo"], y: Float[torch.Tensor, "#foo"]) -> Float[torch.Tensor, "#foo"]`.
- Prepend `_` to an axis to disable any runtime checking of that axis (so that it can be used just as documentation). This can also be used as just `_` on its own: e.g. `"b c _ _"`.
- Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[Array, "rows=4 cols=3"]`.
- Documentation-only names (i.e. they're ignored by jaxtyping) can be handled by prepending a name followed by `=` e.g. `Float[torch.Tensor, "rows=4 cols=3"]`.
- Prepend `?` to an axis to indicate that its size can vary within a PyTree structure. (See [PyTree annotations](./pytree.md).)

When using multiple modifiers, their order does not matter.
Expand All @@ -35,72 +55,40 @@ As a special case:

**Notes**

- To denote a scalar shape use `""`, e.g. `Float[Array, ""]`.
- To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[Array, "..."]`.
- To denote a scalar shape use `""`, e.g. `Float[torch.Tensor, ""]`.
- To denote an arbitrary shape (and only check dtype) use `"..."`, e.g. `Float[torch.Tensor, "..."]`.
- You cannot have more than one use of multiple-axes, i.e. you can only use `...` or `*name` at most once in each array.
- A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments.
- Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`, e.g.
```python
def full(size: int, fill: float) -> Float[Array, "{size}"]:
def full(size: int, fill: float) -> Float[jax.Array, "{size}"]:
return jax.numpy.full((size,), fill)

class SomeClass:
some_value = 5

def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
def full(self, fill: float) -> Float[jax.Array, "{self.some_value}+3"]:
return jax.numpy.full((self.some_value + 3,), fill)
```

## Dtype

The dtype should be any one of (all imported from `jaxtyping`):
## Array

- Any dtype at all: `Shaped`
- Boolean: `Bool`
- PRNG key: `Key`
- Any integer, unsigned integer, floating, or complex: `Num`
- Any floating or complex: `Inexact`
- Any floating point: `Float`
- Of particular precision: `BFloat16`, `Float16`, `Float32`, `Float64`
- Any complex: `Complex`
- Of particular precision: `Complex64`, `Complex128`
- Any integer or unsigned intger: `Integer`
- Any unsigned integer: `UInt`
- Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
- Any signed integer: `Int`
- Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
- Any floating, integer, or unsigned integer: `Real`.
A variety of types are supported here:

Unless you really want to force a particular precision, then for most applications you should probably allow any floating-point, any integer, etc. That is, use
```python
from jaxtyping import Array, Float
Float[Array, "some_shape"]
```
rather than
```python
from jaxtyping import Array, Float32
Float32[Array, "some_shape"]
```
**Arrays and Tensors:**

## Array
The following frameworks are supported:

The array should typically be either one of:
```python
jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another
jax.Array / jax.numpy.ndarray # these are both aliases of one another
np.ndarray
torch.Tensor
tf.Tensor
mx.array
```
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow + MLX.

Some other types are also supported here:
_Despite the now-historical name, 'jax'typing also supports NumPy + PyTorch + TensorFlow + MLX._

**Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`.

**Any:** use `typing.Any` to check just the shape/dtype, but not the array type.

**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example,
**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. The shape should be a `tuple[int, ...]` and the dtype should be a `str`. For example,
```python
class MyDuckArray:
@property
Expand All @@ -121,31 +109,34 @@ assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"])
# and that `x.dtype == "my_dtype"`
```

**Any:** use `typing.Any` to check just the shape/dtype, but not the array type.

**Unions:** these are unpacked. For example, `SomeDtype[A | B, "some shape"]` is equivalent to
`SomeDtype[A, "some shape"] | SomeDtype[B, "some shape"]`.

**TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`.

**Existing jaxtyped annotations:**
**TypeAliasTypes:** Python 3.12 introduced the ability to write `type Foo = int | str`, in which case `Foo` is of type `typing.TypeAliasType`. In this case `SomeDtype[Foo, "some shape"]` corresponds to using the definition provided on the right hand side.

**Existing jaxtyping annotations:**
```python
Image = Float[Array, "channels height width"]
Image = Float[jax.Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```
in which case the additional shape is prepended, and the acceptable dtypes are the intersection of the two dtype specifiers used. (So that e.g. `BatchImage = Shaped[Image, "batch"]` would work just as well. But `Bool[Image, "batch"]` would throw an error, as there are no dtypes that are both bools and floats.) Thus the above is equivalent to
```python
BatchImage = Float[Array, "batch channels height width"]
BatchImage = Float[jax.Array, "batch channels height width"]
```

Note that `jaxtyping.{Array, ArrayLike}` are only available if JAX has been installed.

## Scalars, PRNG keys
## JAX-specific types

For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
```python
Scalar = Shaped[Array, ""]
ScalarLike = Shaped[ArrayLike, ""]

# Left: new-style typed keys; right: old-style keys. See JEP 9263.
PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
```
As `jaxtyping` originally got its start as a JAX-specific library, then we provide some JAX-specific types. These are all only available if JAX is installed.

Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
- `jaxtyping.Array`: alias for `jax.Array`
- `jaxtyping.ArrayLike`: alias for `jax.typing.ArrayLike`
- `jaxtyping.Scalar`: alias for `jaxtyping.Shaped[jax.Array, ""]`
- `jaxtyping.ScalarLike`: alias for `jaxtyping.Shaped[jax.typing.ArrayLike, ""]`
- `jaxtyping.Key`, which is the dtype of `jax.random.key`s. For example `jax.random.key(...)` produces a `jaxtyping.Key[jax.Array, ""]`.
- `jaxtyping.PRNGKeyArray`: alias for `jaxtyping.Key[jax.Array, ""] | jaxtyping.UInt32[jax.Array, "2"]` (Left: new-style typed keys; right: old-style keys. See [JEP 9263](https://docs.jax.dev/en/latest/jep/9263-typed-keys.html).)

Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.
Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`.
9 changes: 4 additions & 5 deletions docs/api/pytree.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# PyTree annotations

This is a JAX-specific feature, and is only available if JAX is installed.

:::jaxtyping.PyTree
options:
members: []
Expand All @@ -17,8 +19,8 @@
The prefix `?` may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:
```python
def f(
x: PyTree[Shaped[Array, "?foo"], "T"],
y: PyTree[Shaped[Array, "?foo"], "T"],
x: PyTree[Shaped[jax.Array, "?foo"], "T"],
y: PyTree[Shaped[jax.Array, "?foo"], "T"],
):
pass
```
Expand All @@ -42,6 +44,3 @@ f((x1, x1), (y0, y1)) # x1 does not have a size matching y0!

Internally, all that is happening is that `foo` is replaced with `0foo` for the first leaf, `1foo` for the next leaf, etc., so that each leaf gets a unique version of the name.

---

Note that `jaxtyping.{PyTree, PyTreeDef}` are only available if JAX has been installed.
10 changes: 6 additions & 4 deletions docs/api/runtime-type-checking.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

(See the [FAQ](../faq.md) for details on static type checking.)

Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance.

There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase.

In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.)
In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions `3` and `4` have some known issues.)

!!! warning

Expand Down Expand Up @@ -48,7 +46,11 @@ import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype # or any other runtime type checker
```
Place this at the start of your notebook -- everything that is directly defined in the notebook, after this magic is run, will be hook'd.
Place this at the start of your notebook – everything that is directly defined in the notebook, after this magic is run, will be hook'd.

#### Interaction with `jax.jit`

Runtime type checking **synergises beautifully with `jax.jit`!** All shape checks will be performed only whilst tracing, and will not impact runtime performance.

#### Other runtime type-checking libraries

Expand Down
2 changes: 1 addition & 1 deletion docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ In type annotations, strings are used for two different things. Sometimes they'r

Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do).

In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before.
In the case of `flake8`, or Ruff, this can be resolved. Multi-dimensional arrays (e.g. `Float32[jax.Array, "b c"]`) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. `Float32[jax.Array, "x"]`) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should instead prepend a space to the start of your shape, e.g. `Float32[jax.Array, " x"]`. `jaxtyping` will treat this in the same way, whilst `flake8` will now throw an F722 error that you can disable as before.

## Dataclass annotations aren't being checked properly.

Expand Down
37 changes: 13 additions & 24 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
# Getting started

jaxtyping is a library providing type annotations **and runtime type-checking** for:
A library providing type annotations **and runtime type-checking** for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.

1. shape and dtype of [JAX](https://github.com/google/jax) arrays; *(Now also supports PyTorch, NumPy, MLX, and TensorFlow!)*
2. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
_The name 'jax'typing is now historical, we support all of the above and have no JAX dependency!_

```python
from jaxtyping import Float
from torch import Tensor

# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
y: Float[Tensor, "dim2 dim3"]
) -> Float[Tensor, "dim1 dim3"]:
...
```

## Installation

Expand All @@ -13,29 +23,8 @@ pip install jaxtyping

Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).

## Example

```python
from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
# You can replace `Array` with `torch.Tensor` etc.
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
...

def accepts_pytree_of_ints(x: PyTree[int]):
...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
...
```

## Next steps

Have a read of the [Array annotations](./api/array.md) documentation on the left-hand bar!
Expand Down
Loading