Skip to content

Commit 68f999e

Browse files
committed
Production hardening: input validation, DRY memo, import efficiency
- Add ShapeMemo.snapshot()/restore() to eliminate duplicated 10-line save/restore pattern in _StructChecker and _ArrayLikeChecker - Validate `casting` parameter in make_array_like_type and make_scalar_like_type (reject invalid strings early) - Validate `byteorder` parameter in DtypeSpec.__post_init__ - Move XLikeNumpy imports under TYPE_CHECKING in jax.py/torch.py (23 fewer attribute lookups at runtime per backend import) - Tighten _check_fixed_dims signature (remove unused list type) - Add test_coverage_edges.py to numpy tox backend set in conftest.py - Add tests for all new validation paths + snapshot/restore
1 parent eb90aa0 commit 68f999e

File tree

9 files changed

+149
-86
lines changed

9 files changed

+149
-86
lines changed

src/shapix/_array_types.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545

4646
__all__ = ["make_array_type", "make_array_like_type"]
4747

48+
_VALID_CASTINGS = frozenset({"no", "equiv", "safe", "same_kind", "unsafe"})
49+
4850

4951
# ---------------------------------------------------------------------------
5052
# Validator callable (used inside beartype's Is[...])
@@ -90,22 +92,12 @@ def __call__(self, obj: object) -> bool:
9092
return False
9193

9294
memo = get_memo(_depth=3)
93-
94-
# Snapshot memo state so we can restore on failure (avoid polluting
95-
# the memo with partial bindings from a bad argument).
96-
single_snap = memo.single.copy()
97-
variadic_snap = memo.variadic.copy()
98-
structures_snap = memo.structures.copy()
95+
snap = memo.snapshot()
9996

10097
result = check_shape(tuple(shape), self._shape_spec, memo) == ""
10198

10299
if not result:
103-
memo.single.clear()
104-
memo.single.update(single_snap)
105-
memo.variadic.clear()
106-
memo.variadic.update(variadic_snap)
107-
memo.structures.clear()
108-
memo.structures.update(structures_snap)
100+
memo.restore(snap)
109101
self._fail_obj = obj
110102

111103
return result
@@ -242,21 +234,10 @@ def _check(self, obj: object, shape: tuple[int, ...], memo: ShapeMemo) -> bool:
242234
if not self._check_dtype(obj):
243235
return False
244236

245-
# Snapshot memo state so we can restore on failure
246-
single_snap = memo.single.copy()
247-
variadic_snap = memo.variadic.copy()
248-
structures_snap = memo.structures.copy()
249-
237+
snap = memo.snapshot()
250238
result = check_shape(shape, self._shape_spec, memo) == ""
251-
252239
if not result:
253-
memo.single.clear()
254-
memo.single.update(single_snap)
255-
memo.variadic.clear()
256-
memo.variadic.update(variadic_snap)
257-
memo.structures.clear()
258-
memo.structures.update(structures_snap)
259-
240+
memo.restore(snap)
260241
return result
261242

262243
def _check_dtype(self, obj: object) -> bool:
@@ -350,6 +331,9 @@ def make_array_like_type(
350331
F32Like = make_array_like_type(FLOAT32, name="F32Like")
351332
F32Like[N, C, H, W] # → Annotated[object, Is[...]]
352333
"""
334+
if casting not in _VALID_CASTINGS:
335+
msg = f"Invalid casting {casting!r}, must be one of {sorted(_VALID_CASTINGS)}"
336+
raise ValueError(msg)
353337
return _ArrayLikeFactory(dtype_spec, casting=casting, name=name)
354338

355339

src/shapix/_dtypes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from __future__ import annotations
1414

15+
import typing as tp
1516
from dataclasses import dataclass, field
1617

1718
__all__ = [
@@ -159,6 +160,18 @@ class DtypeSpec:
159160
byteorder: str = "any"
160161
_structured: object = field(default=None, repr=False, compare=False)
161162

163+
_VALID_BYTEORDERS: tp.ClassVar[frozenset[str]] = frozenset({
164+
"any",
165+
"little",
166+
"big",
167+
"native",
168+
})
169+
170+
def __post_init__(self) -> None:
171+
if self.byteorder not in self._VALID_BYTEORDERS:
172+
msg = f"Invalid byteorder {self.byteorder!r}, must be one of {sorted(self._VALID_BYTEORDERS)}"
173+
raise ValueError(msg)
174+
162175
def matches(self, obj: object) -> bool:
163176
"""Return ``True`` if *obj*'s dtype matches this spec."""
164177
dtype_str = extract_dtype_str(obj)

src/shapix/_memo.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,29 @@ class ShapeMemo:
3535
structures: dict[str, object] = field(default_factory=dict)
3636
"""Tree structure bindings: ``{"T": <TreeSpec>}``."""
3737

38+
def snapshot(
39+
self,
40+
) -> tuple[
41+
dict[str, int], dict[str, tuple[bool, tuple[int, ...]]], dict[str, object]
42+
]:
43+
"""Capture a copy of all current bindings."""
44+
return self.single.copy(), self.variadic.copy(), self.structures.copy()
45+
46+
def restore(
47+
self,
48+
snap: tuple[
49+
dict[str, int], dict[str, tuple[bool, tuple[int, ...]]], dict[str, object]
50+
],
51+
) -> None:
52+
"""Roll back all bindings to a previous snapshot."""
53+
single, variadic, structures = snap
54+
self.single.clear()
55+
self.single.update(single)
56+
self.variadic.clear()
57+
self.variadic.update(variadic)
58+
self.structures.clear()
59+
self.structures.update(structures)
60+
3861

3962
_local = threading.local()
4063

src/shapix/_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def check_shape(
194194

195195

196196
def _check_fixed_dims(
197-
spec: tuple[DimSpec, ...] | list[DimSpec], shape: tuple[int, ...], memo: ShapeMemo
197+
spec: tuple[DimSpec, ...], shape: tuple[int, ...], memo: ShapeMemo
198198
) -> str:
199199
for dim, size in zip(spec, shape):
200200
err = _check_one(dim, size, memo)

src/shapix/jax.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -211,36 +211,7 @@ def forward(x: F32[N, C, H, W]) -> BF16[N, C, H, W]: ...
211211
# Like types (scalar | array | nested sequences — for input validation)
212212
# ---------------------------------------------------------------------------
213213

214-
from .numpy import (
215-
BoolLike as BoolLikeNumpy,
216-
C64Like as C64LikeNumpy,
217-
C128Like as C128LikeNumpy,
218-
ComplexLike as ComplexLikeNumpy,
219-
F16Like as F16LikeNumpy,
220-
F32Like as F32LikeNumpy,
221-
F64Like as F64LikeNumpy,
222-
FloatLike as FloatLikeNumpy,
223-
InexactLike as InexactLikeNumpy,
224-
I8Like as I8LikeNumpy,
225-
I16Like as I16LikeNumpy,
226-
I32Like as I32LikeNumpy,
227-
I64Like as I64LikeNumpy,
228-
IntegerLike as IntegerLikeNumpy,
229-
IntLike as IntLikeNumpy,
230-
NumLike as NumLikeNumpy,
231-
RealLike as RealLikeNumpy,
232-
ShapedLike as ShapedLikeNumpy,
233-
U8Like as U8LikeNumpy,
234-
U16Like as U16LikeNumpy,
235-
U32Like as U32LikeNumpy,
236-
U64Like as U64LikeNumpy,
237-
UIntLike as UIntLikeNumpy,
238-
)
239-
240-
# ---------------------------------------------------------------------------
241-
# ScalarLike types (re-exported from numpy — no shape, just value validation)
242-
# ---------------------------------------------------------------------------
243-
214+
# ScalarLike types + factory (re-exported from numpy — no shape, just value)
244215
from .numpy import BoolScalarLike as BoolScalarLike
245216
from .numpy import C64ScalarLike as C64ScalarLike
246217
from .numpy import C128ScalarLike as C128ScalarLike
@@ -268,6 +239,32 @@ def forward(x: F32[N, C, H, W]) -> BF16[N, C, H, W]: ...
268239
from .numpy import make_scalar_like_type as make_scalar_like_type
269240

270241
if tp.TYPE_CHECKING:
242+
from .numpy import (
243+
BoolLike as BoolLikeNumpy,
244+
C64Like as C64LikeNumpy,
245+
C128Like as C128LikeNumpy,
246+
ComplexLike as ComplexLikeNumpy,
247+
F16Like as F16LikeNumpy,
248+
F32Like as F32LikeNumpy,
249+
F64Like as F64LikeNumpy,
250+
FloatLike as FloatLikeNumpy,
251+
InexactLike as InexactLikeNumpy,
252+
I8Like as I8LikeNumpy,
253+
I16Like as I16LikeNumpy,
254+
I32Like as I32LikeNumpy,
255+
I64Like as I64LikeNumpy,
256+
IntegerLike as IntegerLikeNumpy,
257+
IntLike as IntLikeNumpy,
258+
NumLike as NumLikeNumpy,
259+
RealLike as RealLikeNumpy,
260+
ShapedLike as ShapedLikeNumpy,
261+
U8Like as U8LikeNumpy,
262+
U16Like as U16LikeNumpy,
263+
U32Like as U32LikeNumpy,
264+
U64Like as U64LikeNumpy,
265+
UIntLike as UIntLikeNumpy,
266+
)
267+
271268
type BoolLike[*Dims] = Bool[*Dims] | BoolLikeNumpy[*Dims]
272269

273270
type I8Like[*Dims] = I8[*Dims] | I8LikeNumpy[*Dims]

src/shapix/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,12 @@ def make_scalar_like_type(
463463
type
464464
An ``Annotated`` type that beartype validates at runtime.
465465
"""
466+
from ._array_types import _VALID_CASTINGS
467+
468+
if casting not in _VALID_CASTINGS:
469+
msg = f"Invalid casting {casting!r}, must be one of {sorted(_VALID_CASTINGS)}"
470+
raise ValueError(msg)
471+
466472
target = np.dtype(target_dtype)
467473

468474
def _check(value: object) -> bool:

src/shapix/torch.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -206,36 +206,7 @@ def forward(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
206206
# Like types (scalar | tensor | nested sequences — for input validation)
207207
# ---------------------------------------------------------------------------
208208

209-
from .numpy import (
210-
BoolLike as BoolLikeNumpy,
211-
C64Like as C64LikeNumpy,
212-
C128Like as C128LikeNumpy,
213-
ComplexLike as ComplexLikeNumpy,
214-
F16Like as F16LikeNumpy,
215-
F32Like as F32LikeNumpy,
216-
F64Like as F64LikeNumpy,
217-
FloatLike as FloatLikeNumpy,
218-
InexactLike as InexactLikeNumpy,
219-
I8Like as I8LikeNumpy,
220-
I16Like as I16LikeNumpy,
221-
I32Like as I32LikeNumpy,
222-
I64Like as I64LikeNumpy,
223-
IntegerLike as IntegerLikeNumpy,
224-
IntLike as IntLikeNumpy,
225-
NumLike as NumLikeNumpy,
226-
RealLike as RealLikeNumpy,
227-
ShapedLike as ShapedLikeNumpy,
228-
U8Like as U8LikeNumpy,
229-
U16Like as U16LikeNumpy,
230-
U32Like as U32LikeNumpy,
231-
U64Like as U64LikeNumpy,
232-
UIntLike as UIntLikeNumpy,
233-
)
234-
235-
# ---------------------------------------------------------------------------
236-
# ScalarLike types (re-exported from numpy — no shape, just value validation)
237-
# ---------------------------------------------------------------------------
238-
209+
# ScalarLike types + factory (re-exported from numpy — no shape, just value)
239210
from .numpy import BoolScalarLike as BoolScalarLike
240211
from .numpy import C64ScalarLike as C64ScalarLike
241212
from .numpy import C128ScalarLike as C128ScalarLike
@@ -263,6 +234,32 @@ def forward(x: F32[N, C, H, W]) -> F32[N, C, H, W]: ...
263234
from .numpy import make_scalar_like_type as make_scalar_like_type
264235

265236
if tp.TYPE_CHECKING:
237+
from .numpy import (
238+
BoolLike as BoolLikeNumpy,
239+
C64Like as C64LikeNumpy,
240+
C128Like as C128LikeNumpy,
241+
ComplexLike as ComplexLikeNumpy,
242+
F16Like as F16LikeNumpy,
243+
F32Like as F32LikeNumpy,
244+
F64Like as F64LikeNumpy,
245+
FloatLike as FloatLikeNumpy,
246+
InexactLike as InexactLikeNumpy,
247+
I8Like as I8LikeNumpy,
248+
I16Like as I16LikeNumpy,
249+
I32Like as I32LikeNumpy,
250+
I64Like as I64LikeNumpy,
251+
IntegerLike as IntegerLikeNumpy,
252+
IntLike as IntLikeNumpy,
253+
NumLike as NumLikeNumpy,
254+
RealLike as RealLikeNumpy,
255+
ShapedLike as ShapedLikeNumpy,
256+
U8Like as U8LikeNumpy,
257+
U16Like as U16LikeNumpy,
258+
U32Like as U32LikeNumpy,
259+
U64Like as U64LikeNumpy,
260+
UIntLike as UIntLikeNumpy,
261+
)
262+
266263
type BoolLike[*Dims] = Bool[*Dims] | BoolLikeNumpy[*Dims]
267264

268265
type I8Like[*Dims] = I8[*Dims] | I8LikeNumpy[*Dims]

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"test_memo.py",
1818
"test_shape.py",
1919
"test_dimensions.py",
20+
"test_coverage_edges.py",
2021
},
2122
"jax": {"test_jax.py", "test_dtypes.py"},
2223
"torch": {"test_torch.py", "test_dtypes.py"},

tests/test_coverage_edges.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import numpy as np
6+
import pytest
67

78
import shapix._memo as memo_mod
89
from beartype import BeartypeConf
@@ -288,6 +289,47 @@ def test_struct_checker_fail_obj_replay(self) -> None:
288289
assert checker(bad) is False # replay
289290

290291

292+
class TestInputValidation:
293+
def test_invalid_casting_in_make_array_like_type(self) -> None:
294+
from shapix._array_types import make_array_like_type
295+
296+
with pytest.raises(ValueError, match="Invalid casting"):
297+
make_array_like_type(FLOAT32, casting="bogus")
298+
299+
def test_invalid_casting_in_make_scalar_like_type(self) -> None:
300+
from shapix.numpy import make_scalar_like_type
301+
302+
with pytest.raises(ValueError, match="Invalid casting"):
303+
make_scalar_like_type(np.float32, casting="bogus")
304+
305+
def test_invalid_byteorder_in_dtype_spec(self) -> None:
306+
from shapix._dtypes import DtypeSpec
307+
308+
with pytest.raises(ValueError, match="Invalid byteorder"):
309+
DtypeSpec("Bad", frozenset({"float32"}), byteorder="wrong")
310+
311+
def test_valid_castings_accepted(self) -> None:
312+
from shapix._array_types import make_array_like_type
313+
314+
for casting in ("no", "equiv", "safe", "same_kind", "unsafe"):
315+
factory = make_array_like_type(FLOAT32, casting=casting)
316+
assert factory is not None
317+
318+
def test_memo_snapshot_restore(self) -> None:
319+
memo = ShapeMemo(single={"N": 5}, variadic={"B": (False, (2, 3))})
320+
memo.structures["T"] = "spec"
321+
snap = memo.snapshot()
322+
323+
memo.single["C"] = 10
324+
memo.variadic["X"] = (True, (1,))
325+
memo.structures["S"] = "other"
326+
327+
memo.restore(snap)
328+
assert memo.single == {"N": 5}
329+
assert memo.variadic == {"B": (False, (2, 3))}
330+
assert memo.structures == {"T": "spec"}
331+
332+
291333
class TestClawWrapper:
292334
def test_shapix_this_package_delegates_to_beartype(self, monkeypatch) -> None:
293335
captured: dict[str, object] = {}

0 commit comments

Comments
 (0)