-
Notifications
You must be signed in to change notification settings - Fork 86
Closed
Description
Reproducible example (h/t Claude Sonnet 4.5) included below.
jaxtyping is extraordinarily useful — huge thanks for the work here, @patrick-kidger. Let me know if this is a confirmed bug; if so, I might be able to dig into it.
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "jax>=0.4.31",
#     "jaxtyping>=0.3.0",
#     "typeguard==2.13.3",
# ]
# ///
"""Minimal reproducible example for a jaxtyping + typeguard bug with PEP 604 unions.

This demonstrates that the ``T | None`` syntax fails to validate jaxtyping array
types, while ``Optional[T]`` and ``Union[T, None]`` work correctly.

Issue: when using ``@jaxtyped`` with typeguard, the PEP 604 union syntax
(``T | None``) does not enforce jaxtyping array constraints such as dtype.

Run with uv (no installation required; uv reads the inline script metadata
in the comment block above):

    uv run jaxtyping_union_bug_reprex.py
"""
from typing import Optional, Union
import jax.numpy as jnp
from jaxtyping import Int, jaxtyped
from typeguard import typechecked as typechecker
# These three functions should behave identically: they differ only in how the
# optional-array annotation is spelled.
@jaxtyped(typechecker=typechecker)
def with_optional(
    x: Optional[Int[jnp.ndarray, " N"]],
) -> int:
    """Return the marker value 1; ``x`` exists only to be type-checked.

    The parameter uses the ``Optional[T]`` spelling, so calling this function
    exercises how ``@jaxtyped`` with the typeguard checker handles that form.
    """
    marker = 1
    return marker
@jaxtyped(typechecker=typechecker)
def with_union(
    x: Union[Int[jnp.ndarray, " N"], None],
) -> int:
    """Return the marker value 2; ``x`` exists only to be type-checked.

    The parameter uses the explicit ``Union[T, None]`` spelling, so calling
    this function exercises how ``@jaxtyped`` with the typeguard checker
    handles that form.
    """
    marker = 2
    return marker
@jaxtyped(typechecker=typechecker)
def with_pipe(
    x: Int[jnp.ndarray, " N"] | None,
) -> int:
    """Return the marker value 3; ``x`` exists only to be type-checked.

    The parameter uses the PEP 604 ``T | None`` spelling, which is the form
    this reproducer reports as not being enforced by ``@jaxtyped`` with the
    typeguard checker.
    """
    marker = 3
    return marker
def main():
    """Demonstrate the bug.

    First calls each function with a correctly typed int32 array, which all
    three should accept. Then calls each one with a float32 array, which all
    three should reject. The final "Actual" line is built from the observed
    results rather than being a hard-coded conclusion.
    """
    int_array = jnp.array([1, 2, 3], dtype=jnp.int32)
    # Explicit float32 so the array matches the "float32" banner below even
    # when JAX 64-bit mode is enabled (the default float dtype is then float64).
    float_array = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)  # Wrong dtype!

    # (printed name, function under test, annotation spelling it uses)
    cases = [
        ("with_optional", with_optional, "Optional[T]"),
        ("with_union", with_union, "Union[T, None]"),
        ("with_pipe", with_pipe, "T | None"),
    ]
    rule = "=" * 70

    print(rule)
    print("Testing with correct int32 array")
    print(rule)
    for name, fn, _ in cases:
        print(f"{name}: {fn(int_array)} ✓")

    print("\n" + rule)
    print("Testing with WRONG dtype (float32 instead of int32)")
    print(rule)
    rejecting = []
    for name, fn, spelling in cases:
        try:
            fn(float_array)
        except TypeError:
            # jaxtyping reports type-check failures as TypeCheckError, a
            # TypeError subclass. Any other exception is a genuine crash and
            # propagates instead of being miscounted as a rejection.
            print(f"{name}: REJECTED float ✓")
            rejecting.append(spelling)
        else:
            print(f"{name}: ACCEPTED float ✗ BUG!")

    # Summarise what actually happened instead of asserting the bug exists.
    if not rejecting:
        actual = "Actual: None of them reject it"
    elif len(rejecting) == len(cases):
        actual = "Actual: All three reject it"
    else:
        verb = "rejects" if len(rejecting) == 1 else "reject"
        actual = f"Actual: Only {' and '.join(rejecting)} {verb} it"

    print("\n" + rule)
    print("Expected: All three should reject the float array")
    print(actual)
    print(rule)
# Script entry point: run the demonstration only when executed directly,
# not when imported.
if __name__ == "__main__":
    main()
Metadata
Metadata
Assignees
Labels
No labels