Skip to content

Commit b94f8cd

Browse files
committed
Fix memo leaking to nested beartype helpers, add TreeChecker replay guard, document Tree structure syntax as runtime-only
- Tag explicit memo entries with owner_code so @shapix.check scope does not bleed into nested plain @beartype calls (Finding 1) - Add _fail_obj replay guard and memo snapshot/restore to _TreeChecker to prevent contradictory beartype error messages (Finding 3) - Document Tree structure args (T, S, ...) as runtime-only in module docstring, README, and typing fixture (Finding 2)
1 parent 72bb61d commit b94f8cd

File tree

7 files changed

+214
-9
lines changed

7 files changed

+214
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ However, some patterns are fundamentally runtime-only and produce type checker e
799799
| Arithmetic | `F32[N + 2]` | `# type: ignore` |
800800
| Custom dimensions | `F32[Vocab, Embed]` | `# type: ignore` or `TYPE_CHECKING` pattern |
801801
| `Value(...)` | `F32[Value("size")]` | `# type: ignore` |
802+
| Tree structure args | `Tree[F32[N], T]`, `Tree[F32[N], T, ...]` | `# type: ignore` — leaf-only `Tree[F32[N, C]]` works |
802803

803804
### Recommended type checker config
804805

src/shapix/_decorator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,27 @@ def decorator(fn: Callable[P, R]) -> Callable[P, R]:
5959
inner = beartype(fn, conf=conf) # type: ignore[call-overload]
6060

6161
if inspect.iscoroutinefunction(fn):
62+
inner_code = getattr(inner, "__code__", None)
6263

6364
@functools.wraps(fn)
6465
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
6566
bound = signature.bind_partial(*args, **kwargs)
6667
bound.apply_defaults()
67-
push_memo(scope=dict(bound.arguments))
68+
push_memo(scope=dict(bound.arguments), owner_code=inner_code)
6869
try:
6970
return await inner(*args, **kwargs) # type: ignore[misc,no-any-return]
7071
finally:
7172
pop_memo()
7273

7374
return async_wrapper # type: ignore[return-value]
7475

76+
inner_code = getattr(inner, "__code__", None)
77+
7578
@functools.wraps(fn)
7679
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
7780
bound = signature.bind_partial(*args, **kwargs)
7881
bound.apply_defaults()
79-
push_memo(scope=dict(bound.arguments))
82+
push_memo(scope=dict(bound.arguments), owner_code=inner_code)
8083
try:
8184
return inner(*args, **kwargs)
8285
finally:

src/shapix/_memo.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,39 @@ def restore(
7979
_explicit_scope_stack: contextvars.ContextVar[tuple[dict[str, object] | None, ...]] = (
8080
contextvars.ContextVar("_explicit_scope_stack", default=())
8181
)
82+
_explicit_owner_stack: contextvars.ContextVar[tuple[types.CodeType | None, ...]] = (
83+
contextvars.ContextVar("_explicit_owner_stack", default=())
84+
)
8285

8386

8487
def push_memo(
85-
memo: ShapeMemo | None = None, *, scope: dict[str, object] | None = None
88+
memo: ShapeMemo | None = None,
89+
*,
90+
scope: dict[str, object] | None = None,
91+
owner_code: types.CodeType | None = None,
8692
) -> ShapeMemo:
87-
"""Push a memo onto the explicit stack. Pair with :func:`pop_memo`."""
93+
"""Push a memo onto the explicit stack. Pair with :func:`pop_memo`.
94+
95+
Parameters
96+
----------
97+
owner_code:
98+
When set, this entry is only visible to frames whose ``f_code``
99+
matches. Untagged entries (``None``) are unconditional and used
100+
by :class:`check_context` and :class:`_TreeChecker`.
101+
"""
88102
if memo is None:
89103
memo = ShapeMemo()
90104
_explicit_stack.set(_explicit_stack.get() + (memo,))
91105
_explicit_scope_stack.set(_explicit_scope_stack.get() + (scope,))
106+
_explicit_owner_stack.set(_explicit_owner_stack.get() + (owner_code,))
92107
return memo
93108

94109

95110
def pop_memo() -> None:
96111
"""Pop the most recent explicit memo."""
97112
_explicit_stack.set(_explicit_stack.get()[:-1])
98113
_explicit_scope_stack.set(_explicit_scope_stack.get()[:-1])
114+
_explicit_owner_stack.set(_explicit_owner_stack.get()[:-1])
99115

100116

101117
# ---------------------------------------------------------------------------
@@ -122,10 +138,22 @@ def get_memo(_depth: int = 2) -> ShapeMemo:
122138
Default ``2`` accounts for: our validator → beartype's ``_is_valid_bool``
123139
→ beartype wrapper.
124140
"""
125-
# 1. Explicit stack takes priority
141+
# 1. Explicit stack takes priority (if owner matches or untagged)
126142
explicit = _explicit_stack.get()
127143
if explicit:
128-
return explicit[-1]
144+
owners = _explicit_owner_stack.get()
145+
owner_code = owners[-1] if owners else None
146+
if owner_code is None:
147+
# Untagged entry (check_context / TreeChecker) — always visible
148+
return explicit[-1]
149+
# Tagged entry — only visible if the caller's frame matches
150+
try:
151+
frame = sys._getframe(_depth)
152+
if frame.f_code is owner_code:
153+
return explicit[-1]
154+
except ValueError:
155+
pass
156+
# Fall through to frame-based detection
129157

130158
# 2. Frame-based detection
131159
try:
@@ -187,7 +215,19 @@ def get_scope(_depth: int = 2) -> dict[str, object]:
187215
"""Return the current runtime scope for ``Value(...)`` expressions."""
188216
explicit = _explicit_scope_stack.get()
189217
if explicit and explicit[-1] is not None:
190-
return explicit[-1]
218+
owners = _explicit_owner_stack.get()
219+
owner_code = owners[-1] if owners else None
220+
if owner_code is None:
221+
# Untagged entry — always visible
222+
return explicit[-1]
223+
# Tagged entry — only visible if the caller's frame matches
224+
try:
225+
frame = sys._getframe(_depth)
226+
if frame.f_code is owner_code:
227+
return explicit[-1]
228+
except ValueError:
229+
pass
230+
# Fall through to frame-based detection
191231

192232
try:
193233
frame = sys._getframe(_depth)

src/shapix/_tree.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
Requires ``optree`` or ``jax`` for tree traversal. Install with
1010
``pip install optree`` or ``pip install jax``.
1111
12+
.. note::
13+
14+
Structure arguments (``T``, ``S``, ``...``) are **runtime-only**.
15+
Type checkers see ``Tree`` as ``Tree[LeafType]`` (one type parameter)
16+
and cannot validate multi-arg structure syntax like ``Tree[F32[N], T]``.
17+
Leaf-only annotations such as ``Tree[F32[N, C]]`` are fully supported
18+
by all type checkers.
19+
1220
Import ``Tree`` from an explicit backend module::
1321
1422
from shapix.optree import Tree # backed by optree
@@ -86,7 +94,7 @@ def __repr__(self) -> str:
8694
class _TreeChecker:
8795
"""Beartype validator for tree leaf types and structure consistency."""
8896

89-
__slots__ = ("_leaf_type", "_structure_spec", "_get_ops", "_repr")
97+
__slots__ = ("_leaf_type", "_structure_spec", "_get_ops", "_repr", "_fail_obj")
9098

9199
def __init__(
92100
self,
@@ -98,10 +106,20 @@ def __init__(
98106
self._leaf_type = leaf_type
99107
self._structure_spec = structure_spec
100108
self._get_ops = get_ops
109+
self._fail_obj: object | None = None
101110
spec_str = f", {structure_spec}" if structure_spec else ""
102111
self._repr = f"Tree[{leaf_type!r}{spec_str}]"
103112

104113
def __call__(self, obj: object) -> bool:
114+
# Replay guard: when beartype's error-generation code re-invokes us
115+
# from a different call-stack frame, the fresh memo would lack prior
116+
# bindings, causing a previously failing check to pass. Keep replaying
117+
# the failure until a successful validation with a (possibly different)
118+
# object clears it. We use ``is`` identity (not ``id()``) to avoid
119+
# false matches from recycled addresses.
120+
if self._fail_obj is not None and self._fail_obj is obj:
121+
return False
122+
105123
tree_ops = self._get_ops()
106124
from beartype.door import is_bearable
107125

@@ -111,12 +129,21 @@ def __call__(self, obj: object) -> bool:
111129
# and can resolve ``Value(...)`` expressions against the same parameters.
112130
memo = get_memo(_depth=3)
113131
scope = get_scope(_depth=3)
132+
snap = memo.snapshot()
114133
push_memo(memo, scope=scope)
115134
try:
116-
return self._validate(obj, tree_ops, is_bearable, memo)
135+
result = self._validate(obj, tree_ops, is_bearable, memo)
117136
finally:
118137
pop_memo()
119138

139+
if not result:
140+
memo.restore(snap)
141+
self._fail_obj = obj
142+
else:
143+
self._fail_obj = None
144+
145+
return result
146+
120147
def _validate(
121148
self, obj: object, tree_ops: tp.Any, is_bearable: tp.Any, memo: ShapeMemo
122149
) -> bool:

tests/test_decorator.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,89 @@ def f(size: int) -> F32[Value("size")]: # type: ignore[valid-type]
208208
assert f(7).shape == (7,)
209209

210210

211+
class TestMemoIsolation:
212+
"""Explicit memo from @shapix.check must not leak to nested @beartype helpers."""
213+
214+
def test_check_outer_plain_beartype_inner_independent_dims(self) -> None:
215+
"""Outer @shapix.check binds N=4, inner plain @beartype binds N=7."""
216+
217+
@beartype
218+
def inner(x: F32[N]) -> F32[N]:
219+
return x
220+
221+
@shapix.check
222+
@beartype
223+
def outer(x: F32[N]) -> F32[N]:
224+
# Call inner with a different N — must succeed independently
225+
inner(np.ones(7, dtype=np.float32))
226+
return x
227+
228+
outer(np.ones(4, dtype=np.float32))
229+
230+
def test_check_outer_plain_beartype_inner_value(self) -> None:
231+
"""Inner plain @beartype with Value resolves from its own frame."""
232+
233+
@beartype
234+
def inner(size: int) -> F32[Value("size")]: # type: ignore[valid-type]
235+
return np.ones(size, dtype=np.float32)
236+
237+
@shapix.check
238+
@beartype
239+
def outer(dummy: int) -> F32[Value("dummy")]: # type: ignore[valid-type]
240+
inner(7)
241+
return np.ones(dummy, dtype=np.float32)
242+
243+
outer(4)
244+
245+
def test_check_context_shares_memo_with_beartype_call(self) -> None:
246+
"""Inside check_context(), @beartype calls share the same memo (by design)."""
247+
248+
@beartype
249+
def helper(x: F32[N]) -> F32[N]:
250+
return x
251+
252+
with shapix.check_context():
253+
from beartype.door import is_bearable
254+
255+
assert is_bearable(np.ones(4, dtype=np.float32), F32[N])
256+
# check_context is untagged — helper shares the memo, so N=4 is bound
257+
helper(np.ones(4, dtype=np.float32)) # same N — OK
258+
with pytest.raises(BeartypeCallHintParamViolation):
259+
helper(np.ones(7, dtype=np.float32)) # different N — fails
260+
261+
def test_async_check_outer_plain_beartype_inner(self) -> None:
262+
"""Async variant: outer @shapix.check does not leak to inner @beartype."""
263+
import asyncio
264+
265+
@beartype
266+
def inner(x: F32[N]) -> F32[N]:
267+
return x
268+
269+
@shapix.check
270+
@beartype
271+
async def outer(x: F32[N]) -> F32[N]:
272+
inner(np.ones(7, dtype=np.float32))
273+
return x
274+
275+
asyncio.run(outer(np.ones(4, dtype=np.float32)))
276+
277+
def test_async_child_task_independent_memo(self) -> None:
278+
"""Child task spawned inside @shapix.check gets isolated memo after parent returns."""
279+
import asyncio
280+
281+
@shapix.check
282+
@beartype
283+
async def parent(x: F32[N]) -> F32[N]:
284+
return x
285+
286+
async def run() -> None:
287+
await parent(np.ones(4, dtype=np.float32))
288+
# After parent returns, a new call with different N works
289+
await parent(np.ones(10, dtype=np.float32))
290+
291+
asyncio.run(run())
292+
293+
211294
class TestAsyncCheckContext:
212295
def test_async_check_context(self) -> None:
213296
"""async with check_context() works correctly."""

tests/test_tree.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,44 @@ def test_parse_spec_args_empty_tuple_error(self) -> None:
13001300
Tree[()]
13011301

13021302

1303+
# =====================================================================
1304+
# Replay guard (_fail_obj) — non-contradictory error messages
1305+
# =====================================================================
1306+
1307+
1308+
class TestReplayGuard:
1309+
"""_TreeChecker must produce non-contradictory beartype error messages."""
1310+
1311+
def test_plain_beartype_tree_param_mismatch_error_message(self) -> None:
1312+
"""Plain @beartype tree param mismatch must not say True == Is[...]."""
1313+
1314+
@shapix.check
1315+
@beartype
1316+
def f(x: Tree[F32[N], T], y: Tree[F32[N], T]) -> Tree[F32[N]]:
1317+
return x
1318+
1319+
x = {"a": np.ones(3, dtype=np.float32)}
1320+
y = [np.ones(3, dtype=np.float32)] # different structure
1321+
with pytest.raises(BeartypeCallHintParamViolation) as exc_info:
1322+
f(x, y)
1323+
# The error message should NOT contain "True ==" which would indicate
1324+
# the replay guard failed (validator passed on re-invocation)
1325+
assert "True ==" not in str(exc_info.value)
1326+
1327+
def test_plain_beartype_tree_return_mismatch_error_message(self) -> None:
1328+
"""Return type tree mismatch must produce non-contradictory error."""
1329+
1330+
@shapix.check
1331+
@beartype
1332+
def f(x: Tree[F32[N], T]) -> Tree[F32[N], T]:
1333+
# Return different structure than input
1334+
return [np.ones(3, dtype=np.float32)]
1335+
1336+
with pytest.raises(BeartypeCallHintReturnViolation) as exc_info:
1337+
f({"a": np.ones(3, dtype=np.float32)})
1338+
assert "True ==" not in str(exc_info.value)
1339+
1340+
13031341
# =====================================================================
13041342
# Backend-specific Tree modules
13051343
# =====================================================================

tests/typing/check_tree.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,16 @@ def identity(x: Tree[object]) -> Tree[object]:
6767

6868
def to_none(x: Tree[object]) -> None:
6969
pass
70+
71+
72+
# ---------------------------------------------------------------------------
73+
# Note: Structure args (T, S, ...) are runtime-only
74+
# ---------------------------------------------------------------------------
75+
# Multi-arg syntax like Tree[F32[N], T] or Tree[F32[N], T, ...] works at
76+
# runtime but is not understood by type checkers (Tree has one type param).
77+
# Use `# type: ignore` for structure-bearing annotations.
78+
# Leaf-only annotations like Tree[int] or Tree[object] work with all checkers.
79+
80+
81+
def leaf_int(x: Tree[int]) -> Tree[int]:
82+
return x

0 commit comments

Comments
 (0)