Skip to content

Commit c596e12

Browse files
authored
Merge branch 'main' into refinement
2 parents 65d766d + 52083c4 commit c596e12

File tree

4 files changed

+97
-15
lines changed

4 files changed

+97
-15
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
fail-fast: false
1313
matrix:
14-
python-version: ["3.9", "3.10", "3.11", "3.12"]
14+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
1515
continue-on-error: true
1616
runs-on: ubuntu-latest
1717
steps:

codeflash/verification/comparator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@
5252
HAS_TORCH = True
5353
except ImportError:
5454
HAS_TORCH = False
55+
try:
56+
import jax # type: ignore
57+
import jax.numpy as jnp # type: ignore
58+
59+
HAS_JAX = True
60+
except ImportError:
61+
HAS_JAX = False
5562

5663

5764
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -106,6 +113,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
106113
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
107114
return comparator(orig_dict, new_dict, superset_obj)
108115

116+
# Handle JAX arrays first to avoid boolean context errors in other conditions
117+
if HAS_JAX and isinstance(orig, jax.Array):
118+
if orig.dtype != new.dtype:
119+
return False
120+
if orig.shape != new.shape:
121+
return False
122+
return bool(jnp.allclose(orig, new, equal_nan=True))
123+
109124
if HAS_SQLALCHEMY:
110125
try:
111126
insp = sqlalchemy.inspection.inspect(orig)

tests/test_comparator.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,81 @@ def test_torch():
710710
assert not comparator(gg, ii)
711711

712712

713+
def test_jax():
714+
try:
715+
import jax.numpy as jnp
716+
except ImportError:
717+
pytest.skip()
718+
719+
# Test basic arrays
720+
a = jnp.array([1, 2, 3])
721+
b = jnp.array([1, 2, 3])
722+
c = jnp.array([1, 2, 4])
723+
assert comparator(a, b)
724+
assert not comparator(a, c)
725+
726+
# Test 2D arrays
727+
d = jnp.array([[1, 2, 3], [4, 5, 6]])
728+
e = jnp.array([[1, 2, 3], [4, 5, 6]])
729+
f = jnp.array([[1, 2, 3], [4, 5, 7]])
730+
assert comparator(d, e)
731+
assert not comparator(d, f)
732+
733+
# Test arrays with different data types
734+
g = jnp.array([1, 2, 3], dtype=jnp.float32)
735+
h = jnp.array([1, 2, 3], dtype=jnp.float32)
736+
i = jnp.array([1, 2, 3], dtype=jnp.int32)
737+
assert comparator(g, h)
738+
assert not comparator(g, i)
739+
740+
# Test 3D arrays
741+
j = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
742+
k = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
743+
l = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
744+
assert comparator(j, k)
745+
assert not comparator(j, l)
746+
747+
# Test arrays with different shapes
748+
m = jnp.array([1, 2, 3])
749+
n = jnp.array([[1, 2, 3]])
750+
assert not comparator(m, n)
751+
752+
# Test empty arrays
753+
o = jnp.array([])
754+
p = jnp.array([])
755+
q = jnp.array([1])
756+
assert comparator(o, p)
757+
assert not comparator(o, q)
758+
759+
# Test arrays with NaN values
760+
r = jnp.array([1.0, jnp.nan, 3.0])
761+
s = jnp.array([1.0, jnp.nan, 3.0])
762+
t = jnp.array([1.0, 2.0, 3.0])
763+
assert comparator(r, s) # NaN == NaN
764+
assert not comparator(r, t)
765+
766+
# Test arrays with infinity values
767+
u = jnp.array([1.0, jnp.inf, 3.0])
768+
v = jnp.array([1.0, jnp.inf, 3.0])
769+
w = jnp.array([1.0, -jnp.inf, 3.0])
770+
assert comparator(u, v)
771+
assert not comparator(u, w)
772+
773+
# Test complex arrays
774+
x = jnp.array([1+2j, 3+4j])
775+
y = jnp.array([1+2j, 3+4j])
776+
z = jnp.array([1+2j, 3+5j])
777+
assert comparator(x, y)
778+
assert not comparator(x, z)
779+
780+
# Test boolean arrays
781+
aa = jnp.array([True, False, True])
782+
bb = jnp.array([True, False, True])
783+
cc = jnp.array([True, True, True])
784+
assert comparator(aa, bb)
785+
assert not comparator(aa, cc)
786+
787+
713788
def test_returns():
714789
a = Success(5)
715790
b = Success(5)

tests/test_instrument_tests.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
129129
codeflash_test_index = codeflash_wrap.index[test_id]
130130
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
131131
"""
132-
if sys.version_info < (3, 12):
133-
expected += """test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
134-
"""
135-
else:
136-
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
132+
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
137133
"""
138134
expected += """print(f'!$######{{test_stdout_tag}}######$!')
139135
exception = None
@@ -189,9 +185,9 @@ def test_sort(self):
189185
)
190186
os.chdir(original_cwd)
191187
assert success
192-
assert new_test == expected.format(
188+
assert new_test.replace('"', "'") == expected.format(
193189
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
194-
)
190+
).replace('"', "'")
195191

196192

197193
def test_perfinjector_only_replay_test() -> None:
@@ -233,11 +229,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
233229
codeflash_test_index = codeflash_wrap.index[test_id]
234230
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
235231
"""
236-
if sys.version_info < (3, 12):
237-
expected += """test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
238-
"""
239-
else:
240-
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
232+
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
241233
"""
242234
expected += """print(f'!$######{{test_stdout_tag}}######$!')
243235
exception = None
@@ -289,9 +281,9 @@ def test_prepare_image_for_yolo():
289281
)
290282
os.chdir(original_cwd)
291283
assert success
292-
assert new_test == expected.format(
284+
assert new_test.replace('"', "'") == expected.format(
293285
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
294-
)
286+
).replace('"', "'")
295287

296288

297289
def test_perfinjector_bubble_sort_results() -> None:

0 commit comments

Comments
 (0)