Skip to content

Commit ceca1b3

Browse files
committed
Refine runtime value shape expression API
1 parent a1558da commit ceca1b3

File tree

10 files changed

+57
-31
lines changed

10 files changed

+57
-31
lines changed

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ and numeric literals. Attribute access and function calls are rejected.
159159

160160
### Runtime value dimensions
161161

162-
Use `Value["expr"]` when a shape depends on a runtime parameter or `self`
162+
Use `Value("expr")` when a shape depends on a runtime parameter or `self`
163163
attribute rather than a previously bound dimension:
164164

165165
```python
@@ -168,19 +168,22 @@ from shapix import Value
168168
from shapix.numpy import F32
169169
import numpy as np
170170

171+
Size = Value("size")
172+
WidthPlus3 = Value("self.width + 3")
173+
171174
@beartype
172-
def full(size: int) -> F32[Value["size"]]:
175+
def full(size: int) -> F32[Size]:
173176
return np.full((size,), 1.0, dtype=np.float32)
174177

175178
class SomeClass:
176179
width = 5
177180

178181
@beartype
179-
def full(self) -> F32[Value["self.width + 3"]]:
182+
def full(self) -> F32[WidthPlus3]:
180183
return np.full((self.width + 3,), 1.0, dtype=np.float32)
181184
```
182185

183-
`Value[...]` uses a restricted arithmetic grammar as well. It allows names,
186+
`Value(...)` uses a restricted arithmetic grammar as well. It allows names,
184187
attribute access, numeric literals, and arithmetic operators, but rejects calls,
185188
indexing, and other arbitrary Python expressions.
186189

src/shapix/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def conv(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
2323
``B``, ``N``, ``P``, ``L``, ``C``, ``D``, ``K``, ``H``, ``W`` — named dimensions.
2424
``__`` — anonymous (match any single dim, no binding).
2525
``Scalar`` — scalar (no dimensions).
26-
``Value["expr"]`` — explicit runtime value expression for shape dims.
26+
``Value("expr")`` — explicit runtime value expression for shape dims.
2727
2828
Tree structure symbols
2929
``T``, ``S`` — named tree structure symbols.

src/shapix/_array_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def _to_shape_spec(dims: tuple[object, ...]) -> tuple[DimSpec, ...]:
426426
else:
427427
msg = (
428428
"Invalid shape token "
429-
f"{d!r}; expected int, Ellipsis, Dimension, Value[...], or a DimSpec"
429+
f"{d!r}; expected int, Ellipsis, Dimension, Value(...), or a DimSpec"
430430
)
431431
raise TypeError(msg)
432432

src/shapix/_dimensions.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def flatten(x: F32[N, C]) -> F32[N * C]: ...
2828
Plain ``int`` values (e.g. ``3``) are also accepted as fixed dimension sizes
2929
when subscripting array types.
3030
31-
Use ``Value["expr"]`` for dimensions that depend on runtime parameters or
31+
Use ``Value("expr")`` for dimensions that depend on runtime parameters or
3232
``self`` attributes rather than previously bound shape names.
3333
"""
3434

@@ -190,11 +190,14 @@ def _dim_spec(self) -> DimSpec | None:
190190

191191

192192
class _ValueExpr:
193-
"""Runtime value expression used by ``Value[...]`` in shape subscripts."""
193+
"""Runtime value expression used by ``Value("...")`` in shape subscripts."""
194194

195195
__slots__ = ("expr", "broadcastable")
196196

197-
def __init__(self, expr: str, *, broadcastable: bool = False) -> None:
197+
def __init__(self, expr: object, *, broadcastable: bool = False) -> None:
198+
if not isinstance(expr, str):
199+
msg = "Value(...) expects a string expression"
200+
raise TypeError(msg)
198201
self.expr = expr
199202
self.broadcastable = broadcastable
200203

@@ -209,17 +212,20 @@ def _dim_spec(self) -> ValueDim:
209212

210213
def __repr__(self) -> str:
211214
prefix = "+" if self.broadcastable else ""
212-
return f'{prefix}Value["{self.expr}"]'
215+
return f'{prefix}Value("{self.expr}")'
213216

214217

215218
# ---------------------------------------------------------------------------
216219
# Pre-defined dimension symbols
217220
# ---------------------------------------------------------------------------
218221

219222
if tp.TYPE_CHECKING:
220-
# Declared as ``Dimension`` so type checkers see the full operator set
221-
# (``__add__``, ``__invert__``, ``__pos__``, …) and ``F32[N, C]`` subscripts.
223+
222224
class Value:
225+
def __new__(cls, expr: str) -> Dimension: ...
226+
227+
def __pos__(self) -> Dimension: ...
228+
223229
@classmethod
224230
def __class_getitem__(cls, expr: str) -> Dimension: ...
225231

@@ -236,20 +242,17 @@ def __class_getitem__(cls, expr: str) -> Dimension: ...
236242
__: Dimension
237243
else:
238244

239-
class Value:
245+
class Value(_ValueExpr):
240246
"""Explicit runtime value expression for shape annotations.
241247
242-
Use ``Value["size"]`` or ``Value["self.some_value + 3"]`` when a shape
248+
Use ``Value("size")`` or ``Value("self.some_value + 3")`` when a shape
243249
depends on a runtime parameter rather than a previously bound dimension.
244250
"""
245251

246252
__slots__ = ()
247253

248-
def __class_getitem__(cls, expr: object) -> _ValueExpr: # type: ignore[misc]
249-
if not isinstance(expr, str):
250-
msg = "Value[...] expects a string expression"
251-
raise TypeError(msg)
252-
return _ValueExpr(expr)
254+
def __class_getitem__(cls, expr: object) -> Value: # type: ignore[misc]
255+
return cls(expr)
253256

254257
# Common named dimensions
255258
Scalar = Dimension("")

src/shapix/_memo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_memo(_depth: int = 2) -> ShapeMemo:
187187

188188

189189
def get_scope(_depth: int = 2) -> dict[str, object]:
190-
"""Return the current runtime scope for ``Value[...]`` expressions."""
190+
"""Return the current runtime scope for ``Value(...)`` expressions."""
191191
explicit = _get_explicit_scope_stack()
192192
if explicit and explicit[-1] is not None:
193193
return explicit[-1]
@@ -201,7 +201,7 @@ def get_scope(_depth: int = 2) -> dict[str, object]:
201201

202202
# beartype wrappers expose runtime arguments as ``args`` / ``kwargs`` plus
203203
# the wrapped function object. Rebind those to parameter names so
204-
# ``Value["size"]`` and ``Value["self.attr"]`` work for both param and
204+
# ``Value("size")`` and ``Value("self.attr")`` work for both param and
205205
# return validation.
206206
fn = locals_map.get("__beartype_func")
207207
args = locals_map.get("args")

src/shapix/_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
- :class:`SymbolicDim` — arithmetic expression evaluated against bound
1414
dimensions (``N+1``, ``H*W``). Optionally broadcastable.
1515
- :class:`ValueDim` — arithmetic expression evaluated against runtime call
16-
scope (``Value["size"]``, ``Value["self.width + 3"]``). Optionally
16+
scope (``Value("size")``, ``Value("self.width + 3")``). Optionally
1717
broadcastable.
1818
- :class:`VariadicDim` — matches zero or more contiguous dims and binds the
1919
matched sub-shape (``~spatial``). Optionally broadcastable.
@@ -96,7 +96,7 @@ class ValueDim:
9696

9797
def __repr__(self) -> str:
9898
prefix = "+" if self.broadcastable else ""
99-
return f'{prefix}Value["{self.expr}"]'
99+
return f'{prefix}Value("{self.expr}")'
100100

101101

102102
@dataclass(frozen=True, slots=True)

src/shapix/_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __call__(self, obj: object) -> bool:
105105
from ._memo import get_memo, get_scope, pop_memo, push_memo
106106

107107
# Bridge memo + runtime scope so leaf checks reuse the caller's bindings
108-
# and can resolve ``Value[...]`` expressions against the same parameters.
108+
# and can resolve ``Value(...)`` expressions against the same parameters.
109109
memo = get_memo(_depth=3)
110110
scope = get_scope(_depth=3)
111111
push_memo(memo, scope=scope)

tests/test_dimensions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import pytest
6+
57
from shapix._dimensions import Dimension, Value
68
from shapix._shape import (
79
ANONYMOUS,
@@ -222,14 +224,23 @@ def test_dim_plus_dim(self) -> None:
222224

223225

224226
class TestValueExpressions:
227+
def test_value_expr_requires_string(self) -> None:
228+
with pytest.raises(TypeError, match=r"Value\(\.\.\.\) expects a string expression"):
229+
Value(1) # type: ignore[arg-type]
230+
225231
def test_value_expr(self) -> None:
226-
spec = Value["size"]._dim_spec # noqa: SLF001
232+
spec = Value("size")._dim_spec # noqa: SLF001
227233
assert isinstance(spec, ValueDim)
228234
assert spec.expr == "size"
229235
assert spec.broadcastable is False
230236

231237
def test_broadcastable_value_expr(self) -> None:
232-
spec = (+Value["size"])._dim_spec # noqa: SLF001
238+
spec = (+Value("size"))._dim_spec # noqa: SLF001
233239
assert isinstance(spec, ValueDim)
234240
assert spec.expr == "size"
235241
assert spec.broadcastable is True
242+
243+
def test_value_getitem_alias(self) -> None:
244+
spec = Value["size"]._dim_spec # noqa: SLF001
245+
assert isinstance(spec, ValueDim)
246+
assert spec.expr == "size"

tests/test_numpy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,10 @@ def f(x: F32[N, C]) -> F32[N * C]:
418418

419419
class TestValueExpressions:
420420
def test_value_expr_from_argument(self) -> None:
421+
Size = Value("size")
422+
421423
@beartype
422-
def f(size: int, x: F32[Value["size"]]) -> F32[Value["size"]]: # noqa: F821
424+
def f(size: int, x: F32[Size]) -> F32[Size]:
423425
return x
424426

425427
arr = np.ones(4, dtype=np.float32)
@@ -430,16 +432,19 @@ def f(size: int, x: F32[Value["size"]]) -> F32[Value["size"]]: # noqa: F821
430432
def test_value_expr_from_self_attribute(self) -> None:
431433
class SomeClass:
432434
size = 5
435+
FullSize = Value("self.size + 3")
433436

434437
@beartype
435-
def full(self) -> F32[Value["self.size + 3"]]: # noqa: F821
438+
def full(self) -> F32[FullSize]:
436439
return np.ones(self.size + 3, dtype=np.float32)
437440

438441
assert SomeClass().full().shape == (8,)
439442

440443
def test_value_expr_can_mix_scope_and_bound_dims(self) -> None:
444+
Padded = Value("N + pad")
445+
441446
@beartype
442-
def f(x: F32[N], pad: int) -> F32[Value["N + pad"]]: # noqa: F821
447+
def f(x: F32[N], pad: int) -> F32[Padded]:
443448
return np.ones(x.shape[0] + pad, dtype=np.float32)
444449

445450
assert f(np.ones(4, dtype=np.float32), 2).shape == (6,)

tests/typing/check_annotations.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,20 @@ def double_channels(x: F32[N, C]) -> F32[N, C * 2]:
192192
return np.concatenate([x, x], axis=1).astype(np.float32)
193193

194194

195+
FromSize = Value("size")
196+
FromSelf = Value("self.offset + size")
197+
198+
195199
@beartype
196-
def from_value(size: int) -> F32[Value["size"]]: # noqa: F821
200+
def from_value(size: int) -> F32[FromSize]:
197201
return np.ones(size, dtype=np.float32)
198202

199203

200204
class SomeClass:
201205
offset = 3
202206

203207
@beartype
204-
def from_self(self, size: int) -> F32[Value["self.offset + size"]]: # noqa: F821
208+
def from_self(self, size: int) -> F32[FromSelf]:
205209
return np.ones(self.offset + size, dtype=np.float32)
206210

207211

0 commit comments

Comments
 (0)