Skip to content

Commit 705a6c3

Browse files
committed
Add symbolic PyTree annotations, expand tox/CI with jax 0.7-0.9 and optree/beartype compat
Replace string-based PyTree structure specs ("T", "S T", "T ...") with symbolic Structure class and pre-exported T/S symbols. Add composite (S[T]), prefix (T[...]), and suffix (..., T) operators. Include optree as optional dependency with full version compat matrix (0.14-0.19), beartype compat (0.20-0.22), and jax 0.5-0.9.
1 parent 042011a commit 705a6c3

File tree

16 files changed

+1418
-69
lines changed

16 files changed

+1418
-69
lines changed

.github/workflows/ci.yml

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,24 +47,71 @@ jobs:
4747
fail-fast: false
4848
matrix:
4949
python: ["3.12", "3.13", "3.14"]
50-
suffix: ["", "-torch", "-jax", "-all"]
51-
name: test (py${{ matrix.python }}${{ matrix.suffix }})
50+
name: test (py${{ matrix.python }})
5251
steps:
5352
- uses: actions/checkout@v6
5453
- uses: astral-sh/setup-uv@v7
55-
- run: uv run tox -e "py${{ matrix.python }}${{ matrix.suffix }}"
54+
- run: uv run tox -e "py${{ matrix.python }}"
5655
env:
5756
TOX_PYTHON: python${{ matrix.python }}
5857

58+
numpy-compat:
59+
runs-on: ubuntu-latest
60+
strategy:
61+
fail-fast: false
62+
matrix:
63+
version: ["numpy22", "numpy23", "numpy24"]
64+
name: ${{ matrix.version }}
65+
steps:
66+
- uses: actions/checkout@v6
67+
- uses: astral-sh/setup-uv@v7
68+
- run: uv run tox -e "${{ matrix.version }}"
69+
70+
jax-compat:
71+
runs-on: ubuntu-latest
72+
strategy:
73+
fail-fast: false
74+
matrix:
75+
version: ["jax05", "jax06", "jax07", "jax08", "jax09"]
76+
name: ${{ matrix.version }}
77+
steps:
78+
- uses: actions/checkout@v6
79+
- uses: astral-sh/setup-uv@v7
80+
- run: uv run tox -e "${{ matrix.version }}"
81+
82+
torch-compat:
83+
runs-on: ubuntu-latest
84+
strategy:
85+
fail-fast: false
86+
matrix:
87+
version: ["torch26", "torch27", "torch28", "torch29", "torch210"]
88+
name: ${{ matrix.version }}
89+
steps:
90+
- uses: actions/checkout@v6
91+
- uses: astral-sh/setup-uv@v7
92+
- run: uv run tox -e "${{ matrix.version }}"
93+
5994
beartype-compat:
6095
runs-on: ubuntu-latest
6196
strategy:
6297
fail-fast: false
6398
matrix:
6499
version: ["bt020", "bt021", "bt022"]
65-
suffix: ["", "-torch", "-jax"]
66-
name: beartype (${{ matrix.version }}${{ matrix.suffix }})
100+
suffix: ["", "-jax", "-torch"]
101+
name: ${{ matrix.version }}${{ matrix.suffix }}
67102
steps:
68103
- uses: actions/checkout@v6
69104
- uses: astral-sh/setup-uv@v7
70105
- run: uv run tox -e "${{ matrix.version }}${{ matrix.suffix }}"
106+
107+
optree-compat:
108+
runs-on: ubuntu-latest
109+
strategy:
110+
fail-fast: false
111+
matrix:
112+
version: ["optree014", "optree015", "optree016", "optree017", "optree018", "optree019"]
113+
name: ${{ matrix.version }}
114+
steps:
115+
- uses: actions/checkout@v6
116+
- uses: astral-sh/setup-uv@v7
117+
- run: uv run tox -e "${{ matrix.version }}"

README.md

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ Shapix turns array shape annotations into **Python objects** that beartype valid
3030
pip install shapix
3131
```
3232

33-
Optional backends:
33+
Shapix has one dependency: [beartype](https://github.com/beartype/beartype). Install your preferred array framework separately:
3434

3535
```bash
36-
pip install shapix[jax] # JAX support
37-
pip install shapix[torch] # PyTorch support
36+
pip install shapix numpy # NumPy
37+
pip install shapix torch # PyTorch
38+
pip install shapix jax # JAX
39+
pip install shapix numpy optree # NumPy + PyTree support
3840
```
3941

4042
## Quick start
@@ -101,7 +103,7 @@ Bind to a size on first occurrence and enforce consistency on subsequent ones.
101103
| `H` | Height |
102104
| `W` | Width |
103105
| `L` | Sequence length |
104-
| `T` | Time / tree |
106+
| `T` | Tree structure |
105107
| `P` | Points / parameters |
106108

107109
```python
@@ -322,6 +324,70 @@ BF16_OR_F32 = DtypeSpec("BF16orF32", frozenset({"bfloat16", "float32"}))
322324
MixedArray = make_array_type(np.ndarray, BF16_OR_F32)
323325
```
324326

327+
## PyTree annotations
328+
329+
PyTree annotations validate all leaves in a nested structure (dicts, lists, tuples, namedtuples). Requires `optree` (`pip install optree`).
330+
331+
### Basic leaf checking
332+
333+
```python
334+
from shapix import PyTree, T, S, N, C
335+
from shapix.numpy import F32
336+
337+
@beartype
338+
def process(data: PyTree[F32[N, C]]) -> PyTree[F32[N, C]]:
339+
...
340+
341+
# All leaves must be F32 arrays with consistent N and C
342+
process({"params": np.ones((3, 4), dtype=np.float32),
343+
"state": np.ones((3, 4), dtype=np.float32)})
344+
```
345+
346+
### Structure binding
347+
348+
Named structure symbols (`T`, `S`) enforce that multiple arguments share identical tree shapes:
349+
350+
```python
351+
@beartype
352+
def add_trees(x: PyTree[F32[N], T], y: PyTree[F32[N], T]) -> PyTree[F32[N]]:
353+
...
354+
355+
add_trees({"a": x1, "b": x2}, {"a": y1, "b": y2}) # OK — same structure
356+
add_trees({"a": x1}, [y1, y2]) # Raises — different structure
357+
```
358+
359+
### Composite structures
360+
361+
Use `S[T]` for nested structure matching — outer structure S, inner structure T:
362+
363+
```python
364+
def f(x: PyTree[int, T], y: PyTree[int, S], z: PyTree[int, S[T]]): ...
365+
```
366+
367+
### Prefix and suffix wildcards
368+
369+
```python
370+
# T[...] — top-level structure matches T, subtrees are arbitrary
371+
def f(x: PyTree[F32[N], T[...]], y: PyTree[F32[N], T[...]]): ...
372+
373+
# ..., T — full structure matches T
374+
def f(x: PyTree[F32[N], ..., T], y: PyTree[F32[N], ..., T]): ...
375+
```
376+
377+
### Custom structure symbols
378+
379+
Create your own with `Structure`:
380+
381+
```python
382+
from shapix import Structure
383+
384+
Params = Structure("Params")
385+
State = Structure("State")
386+
387+
@beartype
388+
def train(params: PyTree[F32[N], Params], state: PyTree[I64[N], State]): ...
389+
```
390+
325391
## Advanced usage
326392

327393
### Package-wide instrumentation with `beartype.claw`
@@ -340,28 +406,29 @@ shapix_this_package()
340406

341407
Every function in your package that uses shapix type annotations will be checked automatically.
342408

343-
### Explicit memo management with `@shapix.check`
409+
### How cross-argument checking works (the memo)
344410

345-
The frame-based memo works automatically with `@beartype` in virtually all cases. For exotic call-stack scenarios, or to combine memo management with custom `BeartypeConf`, use the explicit decorator:
411+
A **dimension memo** maps dimension names to sizes (e.g., `N→4`, `C→3`). Each function call gets a fresh memo. All parameter checks within that call share the same memo — that's how shapix knows `N=4` in `x` must match `N=4` in `y`.
346412

347-
```python
348-
import shapix
349-
from beartype import beartype
413+
This happens **automatically** with `@beartype`. Shapix detects the beartype wrapper frame via `sys._getframe()` and associates a memo with it. No extra decorator needed.
350414

351-
# Option 1: Memo only — pair with @beartype
352-
@shapix.check
353-
@beartype
354-
def f(x: F32[N, C]) -> F32[N, C]:
355-
...
415+
### `@shapix.check` (optional)
356416

357-
# Option 2: Memo + beartype combined with custom config
358-
from beartype._conf.confmain import BeartypeConf
417+
`@shapix.check` provides **explicit** memo management. It's useful in two scenarios:
418+
419+
1. **Combining memo with custom `BeartypeConf`** — a single decorator instead of stacking two:
420+
421+
```python
422+
from beartype import BeartypeConf
359423

360424
@shapix.check(conf=BeartypeConf())
361-
def f(x: F32[N, C]) -> F32[N, C]:
362-
...
425+
def f(x: F32[N, C]) -> F32[N, C]: ...
363426
```
364427

428+
2. **Exotic call stacks** where frame-based detection is unreliable (deep decorator chains, recursive wrappers).
429+
430+
For normal usage, just use `@beartype` — it works out of the box.
431+
365432
### Manual checks with `check_context`
366433

367434
For `isinstance`-style checks outside of decorated functions, use `check_context` with beartype's `is_bearable`:
@@ -383,7 +450,7 @@ Shapix uses three key mechanisms:
383450

384451
1. **`Annotated[T, Is[validator]]`** — Each array type annotation (e.g., `F32[N, C]`) produces a `typing.Annotated` type with a beartype `Is[...]` validator. This lets beartype handle all the dispatch natively.
385452

386-
2. **Frame-based memo management** — beartype's `Is[validator]` call stack is deterministic: `validator → _is_valid_bool → beartype_wrapper`. All parameter checks for one function call share the same wrapper frame. Shapix identifies this frame via `sys._getframe()` and associates a dimension memo (name → size bindings) with it. This is how cross-argument consistency works with zero boilerplate.
453+
2. **Frame-based memo** — beartype's `Is[validator]` call stack is deterministic: `validator → _is_valid_bool → beartype_wrapper`. All parameter checks for one function call share the same wrapper frame. Shapix identifies this frame via `sys._getframe()` and associates a dimension memo (name → size bindings) with it. This is how cross-argument consistency works with zero boilerplate.
387454

388455
3. **Thread-local storage** — Each thread gets its own memo stack via `threading.local()`, ensuring thread safety.
389456

@@ -453,6 +520,7 @@ def f(x: F32[~B, C]) -> F32[~B, C]: # type: ignore[reportInvalidTypeForm]
453520
| BeartypeConf | Not supported (decorator conflict) | Fully supported |
454521
| Type checker | Metaclass magic (confuses pyright) | `Annotated` aliases (clean) |
455522
| Backends | NumPy, JAX | NumPy, JAX, PyTorch |
523+
| PyTree | Built-in with structure binding | Built-in with structure binding (via optree) |
456524
| Dependencies | jaxtyping + beartype | beartype only |
457525
| Custom decorator | Required | Not required |
458526

pyproject.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@ description = "Add your description here"
55
readme = "README.md"
66
authors = [{ name = "acecchini", email = "ale.cecchini.valette@gmail.com" }]
77
requires-python = ">=3.12"
8-
dependencies = ["numpy>=2.2.0", "beartype>=0.22.9"]
8+
dependencies = ["beartype>=0.22.9"]
99

1010
[dependency-groups]
11-
# CPU-only variants for dev/CI testing. Shapix is device-agnostic — it checks
12-
# .shape and .dtype only, so GPU/TPU builds (torch+cuda, jax[cuda12], etc.)
13-
# work without any extra configuration. Users install their own backend.
14-
torch = ["torch>=2.6.0"]
15-
jax = ["jax[cpu]>=0.5.0"]
11+
# Backends are CPU-only for dev/CI. Shapix is device-agnostic — it checks
12+
# .shape and .dtype only. Users install their own numpy/jax/torch/optree.
1613
dev = [
14+
"numpy>=2.2.0",
15+
"torch>=2.6.0",
16+
"jax[cpu]>=0.5.0",
17+
"optree>=0.14.0",
1718
"ruff>=0.15.4",
1819
"pyright>=1.1.408",
1920
"mypy>=1.19.1",

src/shapix/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
2020
Exports
2121
-------
2222
Dimension symbols
23-
``B``, ``N``, ``P``, ``L``, ``C``, ``H``, ``W``, ``T`` — named dimensions.
23+
``B``, ``N``, ``P``, ``L``, ``C``, ``H``, ``W`` — named dimensions.
2424
``__`` — anonymous (match any single dim, no binding).
2525
``Scalar`` — scalar (no dimensions).
2626
27+
PyTree structure symbols
28+
``T``, ``S`` — named tree structure symbols.
29+
:class:`Structure` — create custom structure symbols.
30+
2731
Unary operators (apply to any dimension)
2832
``~N`` — variadic (match zero or more contiguous dims).
2933
``+N`` — broadcastable (size 1 always matches).
@@ -32,6 +36,7 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
3236
Classes
3337
:class:`Dimension` — create custom dimension symbols with arithmetic support.
3438
:class:`DtypeSpec` — describe a set of allowed dtypes by string name.
39+
:class:`PyTree` — subscriptable pytree annotation (requires ``optree``).
3540
3641
Functions
3742
:func:`make_array_type` — create subscriptable array type factories for
@@ -46,6 +51,10 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
4651
from ._array_types import make_array_type as make_array_type
4752
from ._decorator import check as check
4853
from ._decorator import check_context as check_context
54+
from ._pytree import PyTree as PyTree
55+
from ._pytree import S as S
56+
from ._pytree import Structure as Structure
57+
from ._pytree import T as T
4958
from ._dimensions import __ as __
5059
from ._dimensions import B as B
5160
from ._dimensions import C as C
@@ -55,6 +64,5 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
5564
from ._dimensions import N as N
5665
from ._dimensions import P as P
5766
from ._dimensions import Scalar as Scalar
58-
from ._dimensions import T as T
5967
from ._dimensions import W as W
6068
from ._dtypes import DtypeSpec as DtypeSpec

src/shapix/_array_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __call__(self, obj: object) -> bool:
8080
# the memo with partial bindings from a bad argument).
8181
single_snap = memo.single.copy()
8282
variadic_snap = memo.variadic.copy()
83+
structures_snap = memo.structures.copy()
8384

8485
result = check_shape(tuple(shape), self._shape_spec, memo) == ""
8586

@@ -88,6 +89,8 @@ def __call__(self, obj: object) -> bool:
8889
memo.single.update(single_snap)
8990
memo.variadic.clear()
9091
memo.variadic.update(variadic_snap)
92+
memo.structures.clear()
93+
memo.structures.update(structures_snap)
9194
self._fail_id = obj_id
9295

9396
return result

src/shapix/_dimensions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def flatten(x: F32[N, C]) -> F32[N * C]: ...
5353
"C",
5454
"H",
5555
"W",
56-
"T",
5756
# Anonymous
5857
"__",
5958
# Special
@@ -198,7 +197,6 @@ def _dim_spec(self) -> DimSpec | None:
198197
C: Dimension
199198
H: Dimension
200199
W: Dimension
201-
T: Dimension
202200
__: Dimension
203201
else:
204202
# Common named dimensions
@@ -210,7 +208,6 @@ def _dim_spec(self) -> DimSpec | None:
210208
C = Dimension("C")
211209
H = Dimension("H")
212210
W = Dimension("W")
213-
T = Dimension("T")
214211

215212
# Anonymous (match anything, no binding)
216213
__ = Dimension("__")

src/shapix/_memo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class ShapeMemo:
3232
variadic: dict[str, tuple[bool, tuple[int, ...]]] = field(default_factory=dict)
3333
"""Variadic dimension bindings: ``{"spatial": (False, (28, 28))}``."""
3434

35+
structures: dict[str, object] = field(default_factory=dict)
36+
"""PyTree structure bindings: ``{"T": <PyTreeSpec>}``."""
37+
3538

3639
_local = threading.local()
3740

0 commit comments

Comments
 (0)