Commit 7a02c67

Backend-aware Like types, BF16Like, D/K dims, tox matrix trim

- Thread an `asarray` parameter through _ArrayLikeChecker so the JAX/PyTorch Like types use jnp.asarray / torch.as_tensor on the slow path, accepting __jax_array__ and __torch_function__ protocol objects
- Add an exact-match shortcut in _check_dtype for non-numpy dtypes (bfloat16)
- Add BF16Like to jax.py and torch.py (it was missing from the Like type set)
- Add D (embedding dim) and K (num heads) dimension symbols for transformers
- Export __version__ from shapix.__init__
- Trim the tox matrix from ~252 to ~35 envs (oldest+newest strategy)
- Lower the beartype minimum to >=0.20 to match the tox matrix
- Raise a defensive error in the _shape.py fallthrough instead of passing silently
- Document `from __future__ import annotations` compatibility in the README

1 parent 68f999e · commit 7a02c67
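The exact-match dtype shortcut mentioned above exists because `np.can_cast` parses dtype strings with `np.dtype()`, and an unknown name such as `"bfloat16"` (a non-numpy dtype provided by ml_dtypes / JAX / PyTorch) raises `TypeError` rather than returning `False`. A minimal sketch of the failure mode and the workaround, assuming only NumPy is installed (`can_cast_safely` is an illustrative helper, not shapix API):

```python
import numpy as np

def can_cast_safely(source: str, target: str) -> bool:
    """Casting check that tolerates dtype strings NumPy does not know."""
    try:
        return bool(np.can_cast(source, target, casting="same_kind"))
    except TypeError:
        # Unknown dtype string: fall back to an exact-name comparison,
        # mirroring the exact-match shortcut described in the commit.
        return source == target

assert can_cast_safely("float32", "float64")        # normal numpy path
assert not can_cast_safely("bfloat16", "float32")   # unknown to numpy
assert can_cast_safely("bfloat16", "bfloat16")      # exact match passes
```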

File tree

13 files changed: +413 −90 lines

README.md

Lines changed: 29 additions & 4 deletions

````diff
@@ -110,6 +110,8 @@ Bind to a size on first occurrence and enforce consistency on subsequent ones.
 | `N` | Batch size, count |
 | `B` | Batch |
 | `C` | Channels |
+| `D` | Embedding dimension |
+| `K` | Number of heads |
 | `H` | Height |
 | `W` | Width |
 | `L` | Sequence length |
@@ -299,15 +301,15 @@ from shapix.numpy import F32, I64, Shaped  # and many more
 from shapix.jax import F32, BF16
 ```
 
-Same type names as NumPy, plus `BF16`. Base type is `jax.Array`. Also exports `Like` types and `Tree`.
+Same type names as NumPy, plus `BF16` and `BF16Like`. Base type is `jax.Array`. Also exports `Like` types and `Tree`.
 
 ### PyTorch
 
 ```python
 from shapix.torch import F32, BF16
 ```
 
-Same type names as NumPy, plus `BF16`. Base type is `torch.Tensor`. Also exports `Like` types.
+Same type names as NumPy, plus `BF16` and `BF16Like`. Base type is `torch.Tensor`. Also exports `Like` types.
 
 ### Endianness variants
 
@@ -564,6 +566,29 @@ def train(params: Tree[F32[N], Params], state: Tree[I64[N], State]): ...
 
 ## Advanced usage
 
+### `from __future__ import annotations` (PEP 563)
+
+Shapix is fully compatible with `from __future__ import annotations`. The library itself uses it in every source file.
+
+The one rule: **every dimension symbol used in an annotation must be imported in the module scope.** This is true with or without the future import — the difference is only the error you get if you forget:
+
+```python
+from __future__ import annotations
+from shapix import C  # B is NOT imported
+from shapix.numpy import F32
+
+@beartype
+def f(x: F32[~B, C]): ...  # BeartypeDecorHintForwardRefException — B is not in scope
+```
+
+Fix: import `B`:
+
+```python
+from shapix import B, C
+```
+
+This applies to all dimension symbols — named (`N`, `B`), custom (`Vocab = Dimension("Vocab")`), and any symbol used with operators (`~B`, `+N`). The operators (`~`, `+`) are evaluated on the imported object, so the base symbol must be available.
+
 ### Package-wide instrumentation with `beartype.claw`
 
 Instead of decorating each function with `@beartype`, you can instrument an entire package:
@@ -753,7 +778,7 @@ def f(x: F32[~B, C]) -> F32[~B, C]:  # type: ignore[reportInvalidTypeForm]
 
 ### Dimension symbols (`shapix`)
 
-`N`, `B`, `C`, `H`, `W`, `L`, `P` — named dimensions
+`N`, `B`, `C`, `D`, `K`, `H`, `W`, `L`, `P` — named dimensions
 `__` — anonymous dimension
 `Scalar` — scalar (zero dimensions)
 `T`, `S` — tree structure symbols
@@ -775,7 +800,7 @@ def f(x: F32[~B, C]) -> F32[~B, C]:  # type: ignore[reportInvalidTypeForm]
 
 ### JAX/PyTorch (`shapix.jax`, `shapix.torch`)
 
-Same array types as NumPy, plus `BF16`. Both export `Like` types, `ScalarLike` types (re-exported from numpy), and `make_scalar_like_type`. JAX also exports `Tree`.
+Same array types as NumPy, plus `BF16` and `BF16Like`. Both export `Like` types, `ScalarLike` types (re-exported from numpy), and `make_scalar_like_type`. JAX also exports `Tree`.
 
 ### Factories (`shapix`)
````
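The rule documented in the new README section can be demonstrated with the standard library alone. Under `from __future__ import annotations`, annotations are stored as strings and resolved later, so a symbol missing from module scope fails at resolution time, not at definition time. A sketch using `typing.get_type_hints` in place of beartype (the exception type differs, but the cause, an unresolvable name, is the same):

```python
from __future__ import annotations

import typing

C = int  # stands in for an imported dimension symbol; B is never defined

def f(x: C, y: B): ...  # fine at definition time: "B" is just a string

assert f.__annotations__["y"] == "B"  # stored unevaluated

try:
    typing.get_type_hints(f)  # resolution forces the name lookup
    resolved = True
except NameError:
    resolved = False

assert not resolved  # B cannot be resolved, just as beartype would report
```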

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@ readme = "README.md"
 license = "MIT"
 authors = [{ name = "acecchini", email = "ale.cecchini.valette@gmail.com" }]
 requires-python = ">=3.12"
-dependencies = ["beartype>=0.22.9"]
+dependencies = ["beartype>=0.20"]
 keywords = [
     "type-checking",
     "shape",
```

src/shapix/__init__.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -20,7 +20,7 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
 Exports
 -------
 Dimension symbols
-    ``B``, ``N``, ``P``, ``L``, ``C``, ``H``, ``W`` — named dimensions.
+    ``B``, ``N``, ``P``, ``L``, ``C``, ``D``, ``K``, ``H``, ``W`` — named dimensions.
     ``__`` — anonymous (match any single dim, no binding).
     ``Scalar`` — scalar (no dimensions).
@@ -55,6 +55,10 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
 ``is_bearable()`` checks.
 """
 
+from importlib.metadata import version
+
+__version__ = version("shapix")
+
 from ._array_types import make_array_like_type as make_array_like_type
 from ._array_types import make_array_type as make_array_type
 from ._decorator import check as check
@@ -65,8 +69,10 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
 from ._dimensions import __ as __
 from ._dimensions import B as B
 from ._dimensions import C as C
+from ._dimensions import D as D
 from ._dimensions import Dimension as Dimension
 from ._dimensions import H as H
+from ._dimensions import K as K
 from ._dimensions import L as L
 from ._dimensions import N as N
 from ._dimensions import P as P
```
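The `__version__` export added here reads the installed distribution's metadata via `importlib.metadata`, avoiding a duplicated version string in source. A guarded variant of the same pattern (the `PackageNotFoundError` guard and placeholder are illustrative additions; the diff above uses the unguarded form):

```python
from importlib.metadata import PackageNotFoundError, version

def get_version(dist_name: str) -> str:
    """Version of an installed distribution, or a placeholder when absent
    (e.g. running from a source checkout that was never pip-installed)."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "0.0.0.dev0"

assert get_version("surely-not-an-installed-dist") == "0.0.0.dev0"
```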

src/shapix/_array_types.py

Lines changed: 68 additions & 14 deletions

```diff
@@ -174,13 +174,25 @@ class _ArrayLikeChecker:
 
     Handles arrays (objects with ``.shape`` + ``.dtype``), scalars, nested
     sequences, ``__array__`` protocol objects, and buffer protocol objects
-    by converting to a numpy array and checking dtype + shape.
+    by converting to an array and checking dtype + shape.
 
     The *casting* parameter controls dtype strictness using numpy casting rules:
     ``no < equiv < safe < same_kind < unsafe``.
+
+    The optional *asarray* callable allows backends to provide their own
+    conversion function (e.g. ``jnp.asarray``, ``torch.as_tensor``) so that
+    objects implementing backend-specific protocols (``__jax_array__``,
+    ``__torch_function__``) are accepted.
     """
 
-    __slots__ = ("_dtype_spec", "_shape_spec", "_casting", "_repr", "_fail_obj")
+    __slots__ = (
+        "_dtype_spec",
+        "_shape_spec",
+        "_casting",
+        "_asarray",
+        "_repr",
+        "_fail_obj",
+    )
 
     def __init__(
         self,
@@ -189,10 +201,12 @@ def __init__(
         *,
         casting: str,
         name: str,
+        asarray: object | None = None,
     ) -> None:
         self._dtype_spec = dtype_spec
         self._shape_spec = shape_spec
         self._casting = casting
+        self._asarray = asarray
         self._fail_obj: object | None = None
 
         dims = ", ".join(repr(d) for d in shape_spec)
@@ -215,20 +229,34 @@ def __call__(self, obj: object) -> bool:
             self._fail_obj = obj
             return result
 
-        # Slow path: convert scalar / sequence / __array__ / buffer to numpy array
-        try:
-            import numpy as np
-
-            arr = np.asarray(obj)
-        except (TypeError, ValueError):
+        # Slow path: convert scalar / sequence / protocol object to array.
+        # Try the backend-specific converter first (jnp.asarray, torch.as_tensor),
+        # then fall back to np.asarray for scalars and nested sequences.
+        arr = self._convert(obj)
+        if arr is None:
             self._fail_obj = obj
             return False
 
-        result = self._check(arr, tuple(arr.shape), memo)
+        result = self._check(arr, tuple(arr.shape), memo)  # type: ignore[union-attr]
         if not result:
             self._fail_obj = obj
         return result
 
+    def _convert(self, obj: object) -> object | None:
+        """Convert *obj* to an array with ``.shape`` and ``.dtype``, or None."""
+        if self._asarray is not None:
+            try:
+                return self._asarray(obj)  # type: ignore[operator]
+            except Exception:  # noqa: BLE001
+                pass  # fall through to numpy
+
+        try:
+            import numpy as np
+
+            return np.asarray(obj)
+        except (TypeError, ValueError):
+            return None
+
     def _check(self, obj: object, shape: tuple[int, ...], memo: ShapeMemo) -> bool:
         """Validate dtype (with casting rules) then shape (with memo)."""
         if not self._check_dtype(obj):
@@ -254,6 +282,11 @@ def _check_dtype(self, obj: object) -> bool:
         if "*" in self._dtype_spec.allowed:
             return self._dtype_spec._check_byteorder(obj)
 
+        # Exact string match always passes (handles non-numpy dtypes like bfloat16
+        # where np.can_cast would raise TypeError for an unknown dtype string).
+        if source in self._dtype_spec.allowed:
+            return self._dtype_spec._check_byteorder(obj)
+
         import numpy as np
 
         for target in self._dtype_spec.allowed:
@@ -281,11 +314,19 @@ class _ArrayLikeFactory:
     sequences, arrays) with dtype casting awareness.
     """
 
-    __slots__ = ("_dtype_spec", "_casting", "__name__")
+    __slots__ = ("_dtype_spec", "_casting", "_asarray", "__name__")
 
-    def __init__(self, dtype_spec: DtypeSpec, *, casting: str, name: str) -> None:
+    def __init__(
+        self,
+        dtype_spec: DtypeSpec,
+        *,
+        casting: str,
+        name: str,
+        asarray: object | None = None,
+    ) -> None:
         self._dtype_spec = dtype_spec
         self._casting = casting
+        self._asarray = asarray
         self.__name__ = name
 
     def __getitem__(self, dims: object) -> type:
@@ -294,7 +335,11 @@ def __getitem__(self, dims: object) -> type:
 
         shape_spec = _to_shape_spec(dims)
         checker = _ArrayLikeChecker(
-            self._dtype_spec, shape_spec, casting=self._casting, name=self.__name__
+            self._dtype_spec,
+            shape_spec,
+            casting=self._casting,
+            name=self.__name__,
+            asarray=self._asarray,
         )
         return Annotated[object, Is[checker]]  # type: ignore[return-value]
 
@@ -303,7 +348,11 @@ def __repr__(self) -> str:
 
 
 def make_array_like_type(
-    dtype_spec: DtypeSpec, *, casting: str = "same_kind", name: str = "ArrayLike"
+    dtype_spec: DtypeSpec,
+    *,
+    casting: str = "same_kind",
+    name: str = "ArrayLike",
+    asarray: object | None = None,
 ) -> _ArrayLikeFactory:
     """Create a subscriptable array-like type factory.
 
@@ -317,6 +366,11 @@ def make_array_like_type(
         compatibility is checked. Default ``"same_kind"``.
     name:
         Human-readable name for error messages.
+    asarray:
+        Optional callable ``(obj) -> array`` for backend-specific conversion.
+        When provided, it is tried before ``np.asarray`` on the slow path
+        (objects without ``.shape`` / ``.dtype``). Use this to support
+        protocols like ``__jax_array__`` or ``__torch_function__``.
 
     Returns
     -------
@@ -334,7 +388,7 @@ def make_array_like_type(
     if casting not in _VALID_CASTINGS:
         msg = f"Invalid casting {casting!r}, must be one of {sorted(_VALID_CASTINGS)}"
         raise ValueError(msg)
-    return _ArrayLikeFactory(dtype_spec, casting=casting, name=name)
+    return _ArrayLikeFactory(dtype_spec, casting=casting, name=name, asarray=asarray)
 
 
 # ---------------------------------------------------------------------------
```

src/shapix/_dimensions.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -51,6 +51,8 @@ def flatten(x: F32[N, C]) -> F32[N * C]: ...
     "P",
     "L",
     "C",
+    "D",
+    "K",
     "H",
     "W",
     # Anonymous
@@ -195,6 +197,8 @@ def _dim_spec(self) -> DimSpec | None:
 P: Dimension
 L: Dimension
 C: Dimension
+D: Dimension
+K: Dimension
 H: Dimension
 W: Dimension
 __: Dimension
@@ -206,6 +210,8 @@ def _dim_spec(self) -> DimSpec | None:
 P = Dimension("P")
 L = Dimension("L")
 C = Dimension("C")
+D = Dimension("D")
+K = Dimension("K")
 H = Dimension("H")
 W = Dimension("W")
```

src/shapix/_shape.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -234,7 +234,7 @@ def _check_one(dim: DimSpec, size: int, memo: ShapeMemo) -> str:
             return f"dimension '{dim.expr}' evaluated to {expected} but got {size}"
         return ""
 
-    return ""
+    return f"internal error: unrecognized dim spec {dim!r}"
 
 
 def _check_variadic(dim: VariadicDim, shape: tuple[int, ...], memo: ShapeMemo) -> str:
```
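The `_shape.py` change replaces a silent `return ""` (which would report success for any dim spec the branches above failed to recognize) with an explicit internal-error message. The pattern in isolation, as a simplified dispatcher rather than shapix's actual `DimSpec` handling:

```python
def check_one(dim, size):
    """Return "" on match, else an error message; never pass silently."""
    if isinstance(dim, int):   # fixed size
        return "" if dim == size else f"expected {dim} but got {size}"
    if dim is None:            # anonymous: matches any size
        return ""
    # Defensive fallthrough: if a new spec kind is added but not handled
    # above, fail loudly instead of silently accepting any shape.
    return f"internal error: unrecognized dim spec {dim!r}"

assert check_one(3, 3) == ""
assert check_one(3, 4) == "expected 3 but got 4"
assert check_one(None, 7) == ""
assert "internal error" in check_one("oops", 1)
```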
