Skip to content

Optionals via PEP 604 unions fail to validate with typeguard 2.13.3 #349

@brendancooley

Description

@brendancooley

Reproducible example (HT Claude Sonnet 4.5) included below.

jaxtyping is extraordinarily useful. Huge thanks for the work here, @patrick-kidger. Let me know if this is a confirmed bug; if so, I would be happy to dig in on it.

"""Minimal reproducible example for jaxtyping + typeguard bug with PEP 604 unions.

This demonstrates that T | None syntax fails to validate jaxtyping array types
while Optional[T] and Union[T, None] work correctly.

Issue: When using @jaxtyped with typeguard, the PEP 604 union syntax (T | None)
does not properly validate jaxtyping array type constraints like dtype.

Run with uv (no installation required):
    uv run jaxtyping_union_bug_reprex.py
"""

# Inline script metadata (PEP 723). Every line of the block must be a `#`
# comment for tools such as uv to recognise it and install the dependencies.

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "jax>=0.4.31",
#     "jaxtyping>=0.3.0",
#     "typeguard==2.13.3",
# ]
# ///

from typing import Optional, Union

import jax.numpy as jnp
from jaxtyping import Int, jaxtyped
from typeguard import typechecked as typechecker


# These three functions should behave identically
@jaxtyped(typechecker=typechecker)
def with_optional(x: Optional[Int[jnp.ndarray, " N"]]) -> int:
    """Using Optional[T] syntax.

    The argument is never read by the body; it exists only so that the
    decorator has an annotation to check. The spelling of the annotation is
    the thing under test, so it must stay exactly as written.

    Returns:
        The constant 1, which identifies this variant in the printed report.
    """
    return 1


@jaxtyped(typechecker=typechecker)
def with_union(x: Union[Int[jnp.ndarray, " N"], None]) -> int:
    """Using Union[T, None] syntax.

    The argument is never read by the body; it exists only so that the
    decorator has an annotation to check. The spelling of the annotation is
    the thing under test, so it must stay exactly as written.

    Returns:
        The constant 2, which identifies this variant in the printed report.
    """
    return 2


@jaxtyped(typechecker=typechecker)
def with_pipe(x: Int[jnp.ndarray, " N"] | None) -> int:
    """Using T | None syntax (PEP 604).

    This is the variant the report is about: the file's module docstring
    states that this spelling is not validated the way the Optional/Union
    spellings are. The argument is never read by the body; do not rewrite the
    annotation into another form, or the reproducer no longer reproduces.

    Returns:
        The constant 3, which identifies this variant in the printed report.
    """
    return 3


def main():
    """Exercise all three annotated variants and print a side-by-side report.

    Each variant is first called with an int32 array, where any exception
    propagates, and then with a floating-point array. In the second round a
    raised exception is reported as a rejection and a normal return is
    reported as an (unwanted) acceptance.
    """
    rule = "=" * 70
    good = jnp.array([1, 2, 3], dtype=jnp.int32)
    bad = jnp.array([1.0, 2.0, 3.0])  # floating-point values: wrong dtype for Int[...]

    # Labels carry their own trailing padding so the report columns line up.
    variants = (
        ("with_optional: ", with_optional),  # Optional[T]    - should reject
        ("with_union:    ", with_union),     # Union[T, None] - should reject
        ("with_pipe:     ", with_pipe),      # T | None       - reported as NOT rejecting
    )

    print(rule)
    print("Testing with correct int32 array")
    print(rule)
    for label, check in variants:
        print(f"{label}{check(good)} ✓")

    print("\n" + rule)
    print("Testing with WRONG dtype (float32 instead of int32)")
    print(rule)

    for label, check in variants:
        # Any exception counts as a rejection here; the reproducer does not
        # care which checker error type is raised.
        try:
            check(bad)
            print(f"{label}ACCEPTED float ✗ BUG!")
        except Exception:
            print(f"{label}REJECTED float ✓")

    print("\n" + rule)
    print("Expected: All three should reject the float array")
    print("Actual:   Only Optional[T] and Union[T, None] reject it")
    print(rule)

# Run the demonstration only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions