diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index d86b9ef62..a1e8c12eb 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -52,6 +52,13 @@ HAS_TORCH = True except ImportError: HAS_TORCH = False +try: + import jax # type: ignore + import jax.numpy as jnp # type: ignore + + HAS_JAX = True +except ImportError: + HAS_JAX = False def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -106,6 +113,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")} return comparator(orig_dict, new_dict, superset_obj) + # Handle JAX arrays first to avoid boolean context errors in other conditions + if HAS_JAX and isinstance(orig, jax.Array): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + return bool(jnp.allclose(orig, new, equal_nan=True)) + if HAS_SQLALCHEMY: try: insp = sqlalchemy.inspection.inspect(orig) diff --git a/tests/test_comparator.py b/tests/test_comparator.py index dfeedba83..06e692b39 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -710,6 +710,81 @@ def test_torch(): assert not comparator(gg, ii) +def test_jax(): + try: + import jax.numpy as jnp + except ImportError: + pytest.skip() + + # Test basic arrays + a = jnp.array([1, 2, 3]) + b = jnp.array([1, 2, 3]) + c = jnp.array([1, 2, 4]) + assert comparator(a, b) + assert not comparator(a, c) + + # Test 2D arrays + d = jnp.array([[1, 2, 3], [4, 5, 6]]) + e = jnp.array([[1, 2, 3], [4, 5, 6]]) + f = jnp.array([[1, 2, 3], [4, 5, 7]]) + assert comparator(d, e) + assert not comparator(d, f) + + # Test arrays with different data types + g = jnp.array([1, 2, 3], dtype=jnp.float32) + h = jnp.array([1, 2, 3], dtype=jnp.float32) + i = jnp.array([1, 2, 3], dtype=jnp.int32) + assert comparator(g, h) + assert not comparator(g, i) + + # Test 3D arrays + j = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + k = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + l = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]) + assert comparator(j, k) + assert not comparator(j, l) + + # Test arrays with different shapes + m = jnp.array([1, 2, 3]) + n = jnp.array([[1, 2, 3]]) + assert not comparator(m, n) + + # Test empty arrays + o = jnp.array([]) + p = jnp.array([]) + q = jnp.array([1]) + assert comparator(o, p) + assert not comparator(o, q) + + # Test arrays with NaN values + r = jnp.array([1.0, jnp.nan, 3.0]) + s = jnp.array([1.0, jnp.nan, 3.0]) + t = jnp.array([1.0, 2.0, 3.0]) + assert comparator(r, s) # NaN == NaN + assert not comparator(r, t) + + # Test arrays with infinity values + u = jnp.array([1.0, jnp.inf, 3.0]) + v = jnp.array([1.0, jnp.inf, 3.0]) + w = jnp.array([1.0, -jnp.inf, 3.0]) + assert comparator(u, v) + assert not comparator(u, w) + + # Test complex arrays + x = jnp.array([1+2j, 3+4j]) + y = jnp.array([1+2j, 3+4j]) + z = jnp.array([1+2j, 3+5j]) + assert comparator(x, y) + assert not comparator(x, z) + + # Test boolean arrays + aa = jnp.array([True, False, True]) + bb = jnp.array([True, False, True]) + cc = jnp.array([True, True, True]) + assert comparator(aa, bb) + assert not comparator(aa, cc) + + def test_returns(): a = Success(5) b = Success(5)