Skip to content

Commit 8f16e44

Browse files
committed
Fix replay guard stale false negatives on reused prebuilt hints
When a prebuilt hint (e.g. Hint = F32[N]) was shared across parameters, the replay guard stayed armed indefinitely after a cross-arg failure, blocking subsequent standalone is_bearable() calls with the same object. Replace the permanent guard with a _fail_replays countdown (2) matching beartype's error-gen re-invocations (is_valid + get_diagnosis), so the guard clears after error-gen and standalone validation can proceed.
1 parent 908cffd commit 8f16e44

File tree

3 files changed

+92
-2
lines changed

3 files changed

+92
-2
lines changed

src/shapix/_array_types.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,21 @@ class _StructChecker:
6464
``Float32Array[N, C]``) and reused across all functions that share it.
6565
"""
6666

67-
__slots__ = ("_dtype_spec", "_shape_spec", "_repr", "_fail_obj", "_fail_memo")
67+
__slots__ = (
68+
"_dtype_spec",
69+
"_shape_spec",
70+
"_repr",
71+
"_fail_obj",
72+
"_fail_memo",
73+
"_fail_replays",
74+
)
6875

6976
def __init__(self, dtype_spec: DtypeSpec, shape_spec: tuple[DimSpec, ...]) -> None:
7077
self._dtype_spec = dtype_spec
7178
self._shape_spec = shape_spec
7279
self._fail_obj: object | None = None
7380
self._fail_memo: object | None = None
81+
self._fail_replays: int = 0
7482

7583
# Pre-compute repr for beartype error messages
7684
dims = ", ".join(repr(d) for d in shape_spec)
@@ -95,6 +103,9 @@ def __call__(self, obj: object) -> bool:
95103
# - Different memo WITH prior bindings → fresh @beartype call where
96104
# earlier params already bound dims → clear guard and re-validate.
97105
# - Different memo, empty → beartype error-gen → replay failure.
106+
# Beartype re-invokes twice per failing param (is_valid + get_diagnosis).
107+
# A countdown (_fail_replays) clears the guard after those 2 replays so
108+
# later standalone is_bearable() with the same object can re-validate.
98109
# - Untagged explicit memo (check_context) → always re-validate.
99110
if self._fail_obj is not None and self._fail_obj is obj:
100111
if has_untagged_memo():
@@ -113,7 +124,10 @@ def __call__(self, obj: object) -> bool:
113124
self._fail_memo = None
114125
else:
115126
# Different memo, empty = beartype error-gen. Replay failure.
116-
# Stay armed for additional error-gen re-invocations.
127+
self._fail_replays -= 1
128+
if self._fail_replays <= 0:
129+
self._fail_obj = None
130+
self._fail_memo = None
117131
return False
118132

119133
# Dtype check
@@ -136,6 +150,7 @@ def __call__(self, obj: object) -> bool:
136150
if any(snap): # prior bindings from other params
137151
self._fail_obj = obj
138152
self._fail_memo = memo
153+
self._fail_replays = 2
139154

140155
return result
141156

@@ -230,6 +245,7 @@ class _ArrayLikeChecker:
230245
"_repr",
231246
"_fail_obj",
232247
"_fail_memo",
248+
"_fail_replays",
233249
)
234250

235251
def __init__(
@@ -247,6 +263,7 @@ def __init__(
247263
self._asarray = asarray
248264
self._fail_obj: object | None = None
249265
self._fail_memo: object | None = None
266+
self._fail_replays: int = 0
250267

251268
dims = ", ".join(repr(d) for d in shape_spec)
252269
self._repr = f"{name}[{dims}]"
@@ -266,6 +283,11 @@ def __call__(self, obj: object) -> bool:
266283
self._fail_obj = None
267284
self._fail_memo = None
268285
else:
286+
# Error-gen replay (see _StructChecker comment).
287+
self._fail_replays -= 1
288+
if self._fail_replays <= 0:
289+
self._fail_obj = None
290+
self._fail_memo = None
269291
return False
270292

271293
memo = get_memo(_depth=3)
@@ -279,6 +301,7 @@ def __call__(self, obj: object) -> bool:
279301
if not result and has_prior:
280302
self._fail_obj = obj
281303
self._fail_memo = memo
304+
self._fail_replays = 2
282305
return result
283306

284307
# Slow path: convert scalar / sequence / protocol object to array.
@@ -293,6 +316,7 @@ def __call__(self, obj: object) -> bool:
293316
if not result and has_prior:
294317
self._fail_obj = obj
295318
self._fail_memo = memo
319+
self._fail_replays = 2
296320
return result
297321

298322
def _convert(self, obj: object) -> object | None:

src/shapix/_tree.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class _TreeChecker:
101101
"_repr",
102102
"_fail_obj",
103103
"_fail_memo",
104+
"_fail_replays",
104105
)
105106

106107
def __init__(
@@ -115,6 +116,7 @@ def __init__(
115116
self._get_ops = get_ops
116117
self._fail_obj: object | None = None
117118
self._fail_memo: object | None = None
119+
self._fail_replays: int = 0
118120
spec_str = f", {structure_spec}" if structure_spec else ""
119121
self._repr = f"Tree[{leaf_type!r}{spec_str}]"
120122

@@ -135,6 +137,11 @@ def __call__(self, obj: object) -> bool:
135137
self._fail_obj = None
136138
self._fail_memo = None
137139
else:
140+
# Error-gen replay (see _StructChecker comment).
141+
self._fail_replays -= 1
142+
if self._fail_replays <= 0:
143+
self._fail_obj = None
144+
self._fail_memo = None
138145
return False
139146

140147
tree_ops = self._get_ops()
@@ -158,6 +165,7 @@ def __call__(self, obj: object) -> bool:
158165
if any(snap): # prior bindings from other params
159166
self._fail_obj = obj
160167
self._fail_memo = memo
168+
self._fail_replays = 2
161169
else:
162170
self._fail_obj = None
163171
self._fail_memo = None

tests/test_decorator.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,3 +709,61 @@ def g(x: Tree[F32[N], T]) -> None: # type: ignore[valid-type]
709709
pass
710710

711711
g(y) # must pass
712+
713+
714+
class TestReplayGuardReusedHint:
715+
"""Prebuilt hint must not poison standalone validation after cross-arg failure."""
716+
717+
def test_struct_reused_hint_standalone(self) -> None:
718+
"""is_bearable passes for reused F32[N] hint after cross-arg failure."""
719+
from beartype.door import is_bearable
720+
721+
Hint = F32[N]
722+
b = np.ones(3, dtype=np.float32)
723+
724+
@shapix.check
725+
@beartype
726+
def pair(a: Hint, b: Hint) -> None:
727+
pass
728+
729+
with pytest.raises(BeartypeCallHintParamViolation):
730+
pair(np.ones(4, dtype=np.float32), b) # b fails: N=4 vs shape (3,)
731+
732+
assert is_bearable(b, Hint) # must pass: N unbound, shape (3,) binds N=3
733+
734+
def test_arraylike_reused_hint_standalone(self) -> None:
735+
"""is_bearable passes for reused F32Like[N] hint after cross-arg failure."""
736+
from beartype.door import is_bearable
737+
from shapix.numpy import F32Like
738+
739+
Hint = F32Like[N]
740+
lst = [1.0, 2.0, 3.0]
741+
742+
@shapix.check
743+
@beartype
744+
def pair(a: Hint, b: Hint) -> None:
745+
pass
746+
747+
with pytest.raises(BeartypeCallHintParamViolation):
748+
pair([1.0, 2.0, 3.0, 4.0], lst) # lst fails: N=4 vs len 3
749+
750+
assert is_bearable(lst, Hint) # must pass
751+
752+
def test_tree_reused_hint_standalone(self) -> None:
753+
"""is_bearable passes for reused Tree hint after cross-structure failure."""
754+
pytest.importorskip("optree")
755+
from beartype.door import is_bearable
756+
from shapix.optree import Tree
757+
758+
Hint = Tree[F32[N], T] # type: ignore[type-arg]
759+
y = [np.ones(3, dtype=np.float32)]
760+
761+
@shapix.check
762+
@beartype
763+
def pair(x: Hint, y: Hint) -> None: # type: ignore[valid-type]
764+
pass
765+
766+
with pytest.raises(BeartypeCallHintParamViolation):
767+
pair({"a": np.ones(3, dtype=np.float32)}, y)
768+
769+
assert is_bearable(y, Hint) # must pass: T unbound, list structure OK

0 commit comments

Comments
 (0)