Skip to content

Commit 8e6c5ae

Browse files
committed
Rename PyTree → Tree, add backend tree modules, modernize to Python 3.12+
- Rename _pytree.py → _tree.py, PyTree → Tree, _PyTreeChecker → _TreeChecker, _PyTreeFactory → _TreeFactory throughout codebase - Add shapix.optree module (Tree backed by optree) - Add Tree to shapix.jax (backed by jax.tree_util) - _TreeFactory now accepts a get_ops callable for backend parameterization - shapix.Tree auto-detects backend (optree → jax fallback) - Modernize to Python 3.12+ syntax: PEP 695 generics (class Tree[_T], def check[**P, R], type DimSpec = ...), remove TypeVar/ParamSpec/Union - Fix claw.py private import (beartype._conf → beartype.BeartypeConf) - Add multiple variadic validation (F32[~N, ~C] raises TypeError) - Professionalize pyproject.toml (description, license, classifiers, URLs) - Clean up ruff.toml (broader selects, per-file test ignores) - CI: test/compat jobs now need lint+typecheck to pass first - Add edge case tests for Tree repr, factory errors, backend modules - Update README, notebook, and CLAUDE.md for Tree rename + backend imports
1 parent 90a4c82 commit 8e6c5ae

File tree

22 files changed

+1561
-895
lines changed

22 files changed

+1561
-895
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ jobs:
4242
- run: uv run pytest tests/test_typecheck.py -v
4343

4444
test:
45+
needs: [ruff, codespell, typecheck]
4546
runs-on: ubuntu-latest
4647
strategy:
4748
fail-fast: false
@@ -56,6 +57,7 @@ jobs:
5657
TOX_PYTHON: python${{ matrix.python }}
5758

5859
numpy-compat:
60+
needs: [ruff, codespell, typecheck]
5961
runs-on: ubuntu-latest
6062
strategy:
6163
fail-fast: false
@@ -68,6 +70,7 @@ jobs:
6870
- run: uv run tox -e "${{ matrix.version }}"
6971

7072
jax-compat:
73+
needs: [ruff, codespell, typecheck]
7174
runs-on: ubuntu-latest
7275
strategy:
7376
fail-fast: false
@@ -80,6 +83,7 @@ jobs:
8083
- run: uv run tox -e "${{ matrix.version }}"
8184

8285
torch-compat:
86+
needs: [ruff, codespell, typecheck]
8387
runs-on: ubuntu-latest
8488
strategy:
8589
fail-fast: false
@@ -92,6 +96,7 @@ jobs:
9296
- run: uv run tox -e "${{ matrix.version }}"
9397

9498
beartype-compat:
99+
needs: [ruff, codespell, typecheck]
95100
runs-on: ubuntu-latest
96101
strategy:
97102
fail-fast: false
@@ -105,6 +110,7 @@ jobs:
105110
- run: uv run tox -e "${{ matrix.version }}${{ matrix.suffix }}"
106111

107112
optree-compat:
113+
needs: [ruff, codespell, typecheck]
108114
runs-on: ubuntu-latest
109115
strategy:
110116
fail-fast: false

CLAUDE.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ src/shapix/
1313
├── _array_types.py # Array type factory → Annotated[T, Is[checker]]
1414
├── _dimensions.py # Dimension symbols (N, C, ~B, +N, __, etc.)
1515
├── _decorator.py # Optional @shapix.check + check_context
16+
├── _tree.py # Tree annotations with leaf-type and structure checking
1617
├── numpy.py # NumPy: F32, I64, F32Like, ArrayLike, etc.
17-
├── jax.py # JAX: F32, BF16, etc.
18+
├── jax.py # JAX: F32, BF16, Tree (jax.tree_util), etc.
1819
├── torch.py # PyTorch: F32, BF16, etc.
20+
├── optree.py # Tree backed by optree
1921
└── claw.py # Import hook wrapping beartype.claw
2022
```
2123

README.md

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Shapix has one dependency: [beartype](https://github.com/beartype/beartype). Ins
3636
pip install shapix numpy # NumPy
3737
pip install shapix torch # PyTorch
3838
pip install shapix jax # JAX
39-
pip install shapix numpy optree # NumPy + PyTree support
39+
pip install shapix numpy optree # NumPy + tree support (optree or jax)
4040
```
4141

4242
## Quick start
@@ -103,7 +103,6 @@ Bind to a size on first occurrence and enforce consistency on subsequent ones.
103103
| `H` | Height |
104104
| `W` | Width |
105105
| `L` | Sequence length |
106-
| `T` | Tree structure |
107106
| `P` | Points / parameters |
108107

109108
```python
@@ -248,6 +247,7 @@ else:
248247
| `+` | Broadcastable | `+N` | Size 1 always OK |
249248
| `__` | Anonymous | `__` | Match any, no binding |
250249
| `~__` | Anonymous variadic | `~__` | Zero or more, no binding |
250+
| `...` | Ellipsis (alias) | `...` | Same as `~__` |
251251
| arithmetic | Symbolic | `N + 1` | Expression |
252252

253253
## Array types
@@ -324,18 +324,26 @@ BF16_OR_F32 = DtypeSpec("BF16orF32", frozenset({"bfloat16", "float32"}))
324324
MixedArray = make_array_type(np.ndarray, BF16_OR_F32)
325325
```
326326

327-
## PyTree annotations
327+
## Tree annotations
328328

329-
PyTree annotations validate all leaves in a nested structure (dicts, lists, tuples, namedtuples). Requires `optree` (`pip install optree`).
329+
Tree annotations validate all leaves in a nested structure (dicts, lists, tuples, namedtuples). Requires `optree` or `jax` for tree traversal.
330+
331+
**Three ways to import `Tree`:**
332+
333+
```python
334+
from shapix import Tree # auto-detect: tries optree, falls back to jax
335+
from shapix.optree import Tree # explicitly use optree
336+
from shapix.jax import Tree # explicitly use jax.tree_util
337+
```
330338

331339
### Basic leaf checking
332340

333341
```python
334-
from shapix import PyTree, T, S, N, C
342+
from shapix import Tree, T, S, N, C
335343
from shapix.numpy import F32
336344

337345
@beartype
338-
def process(data: PyTree[F32[N, C]]) -> PyTree[F32[N, C]]:
346+
def process(data: Tree[F32[N, C]]) -> Tree[F32[N, C]]:
339347
...
340348

341349
# All leaves must be F32 arrays with consistent N and C
@@ -349,29 +357,35 @@ Named structure symbols (`T`, `S`) enforce that multiple arguments share identic
349357

350358
```python
351359
@beartype
352-
def add_trees(x: PyTree[F32[N], T], y: PyTree[F32[N], T]) -> PyTree[F32[N]]:
360+
def add_trees(x: Tree[F32[N], T], y: Tree[F32[N], T]) -> Tree[F32[N]]:
353361
...
354362

355363
add_trees({"a": x1, "b": x2}, {"a": y1, "b": y2}) # OK — same structure
356364
add_trees({"a": x1}, [y1, y2]) # Raises — different structure
357365
```
358366

359-
### Composite structures
367+
### Multi-level structure matching
360368

361-
Use `S[T]` for nested structure matching — outer structure S, inner structure T:
369+
Structure names are listed left-to-right from outer to inner. Without `...`, each name except the last captures ONE level; the last captures the full remaining structure. Trailing `...` makes all names one-level-only with inner levels unchecked. Leading `...` matches names from the bottom up.
362370

363371
```python
364-
def f(x: PyTree[int, T], y: PyTree[int, S], z: PyTree[int, S[T]]): ...
365-
```
372+
# T = full tree structure (all levels)
373+
def f(x: Tree[F32[N], T], y: Tree[F32[N], T]): ...
366374

367-
### Prefix and suffix wildcards
375+
# T = top-level only, subtrees are arbitrary
376+
def f(x: Tree[F32[N], T, ...], y: Tree[F32[N], T, ...]): ...
368377

369-
```python
370-
# T[...] — top-level structure matches T, subtrees are arbitrary
371-
def f(x: PyTree[F32[N], T[...]], y: PyTree[F32[N], T[...]]): ...
378+
# T = bottom-level only (leaf-adjacent container)
379+
def f(x: Tree[F32[N], ..., T], y: Tree[F32[N], ..., T]): ...
380+
381+
# T = top level, S = full remaining structure below
382+
def f(x: Tree[int, T], y: Tree[int, S], z: Tree[int, T, S]): ...
372383

373-
# ..., T — full structure matches T
374-
def f(x: PyTree[F32[N], ..., T], y: PyTree[F32[N], ..., T]): ...
384+
# T = top, S = next, inner levels unchecked
385+
def f(x: Tree[F32[N], T, S, ...]): ...
386+
387+
# S = bottom, T = second-from-bottom
388+
def f(x: Tree[F32[N], ..., T, S]): ...
375389
```
376390

377391
### Custom structure symbols
@@ -385,7 +399,7 @@ Params = Structure("Params")
385399
State = Structure("State")
386400

387401
@beartype
388-
def train(params: PyTree[F32[N], Params], state: PyTree[I64[N], State]): ...
402+
def train(params: Tree[F32[N], Params], state: Tree[I64[N], State]): ...
389403
```
390404

391405
## Advanced usage
@@ -412,36 +426,74 @@ A **dimension memo** maps dimension names to sizes (e.g., `N→4`, `C→3`). Eac
412426

413427
This happens **automatically** with `@beartype`. Shapix detects the beartype wrapper frame via `sys._getframe()` and associates a memo with it. No extra decorator needed.
414428

415-
### `@shapix.check` (optional)
429+
### `@shapix.check` — explicit memo management
430+
431+
To understand `@shapix.check`, you need to understand the problem it solves.
416432

417-
`@shapix.check` provides **explicit** memo management. It's useful in two scenarios:
433+
**The problem: sharing state across parameter checks.** When beartype validates `f(x, y)`, it checks `x` and `y` independently — it calls the `Is[...]` validator once per parameter. But shapix needs all those validators to share the same dimension memo, so that `N=4` bound by `x` is enforced on `y`. Something has to connect them.
418434

419-
1. **Combining memo with custom `BeartypeConf`** — a single decorator instead of stacking two:
435+
**The automatic approach** (no extra decorator needed) is frame-based detection. Shapix walks up the call stack with `sys._getframe()` to find the beartype wrapper frame. Since all parameter checks within one `f(x, y)` call share the same wrapper frame, shapix can key a memo to it. This just works:
420436

421437
```python
422-
from beartype import BeartypeConf
438+
@beartype # Shapix auto-detects this frame — nothing else needed
439+
def f(x: F32[N, C], y: F32[N, C]) -> F32[N, C]: ...
440+
```
423441

424-
@shapix.check(conf=BeartypeConf())
425-
def f(x: F32[N, C]) -> F32[N, C]: ...
442+
**`@shapix.check`** takes a different approach: instead of detecting the frame, it explicitly pushes a fresh memo onto a stack before the call and pops it after. All validators see this explicit memo first (it takes priority over frame detection):
443+
444+
```python
445+
@shapix.check # Pushes memo before call, pops after
446+
@beartype # Validates parameters using that memo
447+
def f(x: F32[N, C], y: F32[N, C]) -> F32[N, C]: ...
426448
```
427449

428-
2. **Exotic call stacks** where frame-based detection is unreliable (deep decorator chains, recursive wrappers).
450+
Both approaches produce identical results in normal usage. You only need `@shapix.check` in specific situations:
451+
452+
#### 1. Extra decorators between beartype and the call site
429453

430-
For normal usage, just use `@beartype` — it works out of the box.
454+
Frame-based detection counts a fixed number of frames up from the validator. If something adds extra frames between beartype's wrapper and the actual function call, the detection can land on the wrong frame:
455+
456+
```python
457+
# This works — beartype is the outermost wrapper, frame detection is stable
458+
@beartype
459+
def f(x: F32[N, C], y: F32[N, C]) -> F32[N, C]: ...
460+
461+
# This might not — a middleware decorator adds extra frames
462+
@some_middleware # Adds frames between beartype's wrapper and the caller
463+
@beartype
464+
def f(x: F32[N, C], y: F32[N, C]) -> F32[N, C]: ...
465+
466+
# Fix: @shapix.check bypasses frame detection entirely
467+
@some_middleware
468+
@shapix.check
469+
@beartype
470+
def f(x: F32[N, C], y: F32[N, C]) -> F32[N, C]: ...
471+
```
472+
473+
#### 2. Defensive coding
474+
475+
If you want a guarantee that cross-argument checking works regardless of how your code is called (by test runners, async frameworks, deep middleware stacks), `@shapix.check` removes all dependence on call-stack structure.
476+
477+
**When you don't need it:** If you're using plain `@beartype` and your tests pass, the frame-based detection is working. Most applications never need `@shapix.check`.
431478

432479
### Manual checks with `check_context`
433480

434-
For `isinstance`-style checks outside of decorated functions, use `check_context` with beartype's `is_bearable`:
481+
For `isinstance`-style checks outside of decorated functions, use `check_context` with beartype's `is_bearable`. Without it, each check gets an independent memo — dimensions aren't cross-checked:
435482

436483
```python
437484
from beartype.door import is_bearable
438485
import shapix
439486
from shapix import N, C
440487
from shapix.numpy import F32
441488

489+
# Without check_context — each check is independent, N is NOT cross-checked
490+
is_bearable(x, F32[N, C]) # Binds N=4 in a temporary memo (discarded)
491+
is_bearable(y, F32[N, C]) # Binds N=999 in a new memo — no error!
492+
493+
# With check_context — checks share a memo, N IS cross-checked
442494
with shapix.check_context():
443-
assert is_bearable(x, F32[N, C]) # Binds N and C
444-
assert is_bearable(y, F32[N, C]) # Must match same N, C
495+
assert is_bearable(x, F32[N, C]) # Binds N=4
496+
assert is_bearable(y, F32[N, C]) # Checks N=4 — raises if y has N=999
445497
```
446498

447499
## How it works
@@ -520,7 +572,7 @@ def f(x: F32[~B, C]) -> F32[~B, C]: # type: ignore[reportInvalidTypeForm]
520572
| BeartypeConf | Not supported (decorator conflict) | Fully supported |
521573
| Type checker | Metaclass magic (confuses pyright) | `Annotated` aliases (clean) |
522574
| Backends | NumPy, JAX | NumPy, JAX, PyTorch |
523-
| PyTree | Built-in with structure binding | Built-in with structure binding (via optree) |
575+
| Tree | Built-in with structure binding | Built-in with structure binding (via optree) |
524576
| Dependencies | jaxtyping + beartype | beartype only |
525577
| Custom decorator | Required | Not required |
526578

0 commit comments

Comments
 (0)