Skip to content

Commit 47e29ec

Browse files
Merge pull request #584 from codeflash-ai/support-jax-comparator
jax comparator
2 parents ec531ba + e15bb84 commit 47e29ec

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@
5252
HAS_TORCH = True
5353
except ImportError:
5454
HAS_TORCH = False
55+
try:
56+
import jax # type: ignore
57+
import jax.numpy as jnp # type: ignore
58+
59+
HAS_JAX = True
60+
except ImportError:
61+
HAS_JAX = False
5562

5663

5764
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -106,6 +113,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
106113
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
107114
return comparator(orig_dict, new_dict, superset_obj)
108115

116+
# Handle JAX arrays first to avoid boolean context errors in other conditions
117+
if HAS_JAX and isinstance(orig, jax.Array):
118+
if orig.dtype != new.dtype:
119+
return False
120+
if orig.shape != new.shape:
121+
return False
122+
return bool(jnp.allclose(orig, new, equal_nan=True))
123+
109124
if HAS_SQLALCHEMY:
110125
try:
111126
insp = sqlalchemy.inspection.inspect(orig)

tests/test_comparator.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,81 @@ def test_torch():
710710
assert not comparator(gg, ii)
711711

712712

713+
def test_jax():
714+
try:
715+
import jax.numpy as jnp
716+
except ImportError:
717+
pytest.skip()
718+
719+
# Test basic arrays
720+
a = jnp.array([1, 2, 3])
721+
b = jnp.array([1, 2, 3])
722+
c = jnp.array([1, 2, 4])
723+
assert comparator(a, b)
724+
assert not comparator(a, c)
725+
726+
# Test 2D arrays
727+
d = jnp.array([[1, 2, 3], [4, 5, 6]])
728+
e = jnp.array([[1, 2, 3], [4, 5, 6]])
729+
f = jnp.array([[1, 2, 3], [4, 5, 7]])
730+
assert comparator(d, e)
731+
assert not comparator(d, f)
732+
733+
# Test arrays with different data types
734+
g = jnp.array([1, 2, 3], dtype=jnp.float32)
735+
h = jnp.array([1, 2, 3], dtype=jnp.float32)
736+
i = jnp.array([1, 2, 3], dtype=jnp.int32)
737+
assert comparator(g, h)
738+
assert not comparator(g, i)
739+
740+
# Test 3D arrays
741+
j = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
742+
k = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
743+
l = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
744+
assert comparator(j, k)
745+
assert not comparator(j, l)
746+
747+
# Test arrays with different shapes
748+
m = jnp.array([1, 2, 3])
749+
n = jnp.array([[1, 2, 3]])
750+
assert not comparator(m, n)
751+
752+
# Test empty arrays
753+
o = jnp.array([])
754+
p = jnp.array([])
755+
q = jnp.array([1])
756+
assert comparator(o, p)
757+
assert not comparator(o, q)
758+
759+
# Test arrays with NaN values
760+
r = jnp.array([1.0, jnp.nan, 3.0])
761+
s = jnp.array([1.0, jnp.nan, 3.0])
762+
t = jnp.array([1.0, 2.0, 3.0])
763+
assert comparator(r, s) # NaN == NaN
764+
assert not comparator(r, t)
765+
766+
# Test arrays with infinity values
767+
u = jnp.array([1.0, jnp.inf, 3.0])
768+
v = jnp.array([1.0, jnp.inf, 3.0])
769+
w = jnp.array([1.0, -jnp.inf, 3.0])
770+
assert comparator(u, v)
771+
assert not comparator(u, w)
772+
773+
# Test complex arrays
774+
x = jnp.array([1+2j, 3+4j])
775+
y = jnp.array([1+2j, 3+4j])
776+
z = jnp.array([1+2j, 3+5j])
777+
assert comparator(x, y)
778+
assert not comparator(x, z)
779+
780+
# Test boolean arrays
781+
aa = jnp.array([True, False, True])
782+
bb = jnp.array([True, False, True])
783+
cc = jnp.array([True, True, True])
784+
assert comparator(aa, bb)
785+
assert not comparator(aa, cc)
786+
787+
713788
def test_returns():
714789
a = Success(5)
715790
b = Success(5)

0 commit comments

Comments
 (0)