Skip to content

Commit 1161111

Browse files
Now more robust to errors during type hint resolution.
1 parent 4af24af commit 1161111

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

jaxtyping/_decorator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,12 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
393393
full_signature = inspect.signature(fn)
394394
try:
395395
destring_annotations = get_type_hints(fn, include_extras=True)
396-
except NameError:
396+
except Exception:
397397
# Best-effort attempt to destringify annotations.
398+
# Not just `NameError` but also e.g. `ValueError` in case we have e.g.
399+
# 'Float[Foo, "*foo *bar"]' and raise from having multiple variadic
400+
# arguments. Sometimes this can still be useful to use for human
401+
# documentation purposes.
398402
pass
399403
else:
400404
new_params = []

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "jaxtyping"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
55
readme = "README.md"
66
requires-python =">=3.10"

test/test_decorator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,11 @@ def __init__(self, x: int):
305305
wrapped = Foo.__init__.__wrapped__
306306
with pytest.raises(AttributeError):
307307
wrapped.__wrapped__
308+
309+
310+
def test_stringified_multiple_varaidic(typecheck):
311+
@jaxtyped(typechecker=typecheck)
312+
def foo() -> 'Float[Array, "*foo *bar"]':
313+
return jnp.arange(3)
314+
315+
foo()

0 commit comments

Comments
 (0)