Skip to content

Commit a7f8816

Browse files
fix tests
1 parent 168118a commit a7f8816

File tree

7 files changed

+61
-30
lines changed

7 files changed

+61
-30
lines changed

codeflash/models/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,8 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option
513513
return None
514514

515515
def get_src_code(self, test_path: Path) -> Optional[str]:
516+
if not test_path.exists():
517+
return None
516518
test_src = test_path.read_text(encoding="utf-8")
517519
module_node = cst.parse_module(test_src)
518520

codeflash/verification/equivalence.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from __future__ import annotations
2+
13
import sys
24
from dataclasses import dataclass
35
from enum import Enum
6+
from typing import TYPE_CHECKING, Optional
47

58
from codeflash.cli_cmds.console import logger
69
from codeflash.models.models import TestResults, TestType, VerificationType
710
from codeflash.verification.comparator import comparator
811

12+
if TYPE_CHECKING:
13+
from codeflash.models.models import TestResults
14+
915
INCREASED_RECURSION_LIMIT = 5000
1016

1117

@@ -19,10 +25,10 @@ class TestDiffScope(Enum):
1925
@dataclass
2026
class TestDiff:
2127
scope: TestDiffScope
22-
test_src_code: str
2328
pytest_error: str
2429
original_value: any
2530
candidate_value: any
31+
test_src_code: Optional[str] = None
2632

2733

2834
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:

tests/test_codeflash_capture.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,8 @@ def __init__(self, x=2):
502502
pytest_max_loops=1,
503503
testing_time=0.1,
504504
)
505-
assert compare_test_results(test_results, test_results2)
505+
match, _ = compare_test_results(test_results, test_results2)
506+
assert match
506507

507508
finally:
508509
test_path.unlink(missing_ok=True)
@@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs):
626627
testing_time=0.1,
627628
)
628629

629-
assert compare_test_results(test_results, results2)
630+
match, _ = compare_test_results(test_results, results2)
631+
assert match
630632

631633
finally:
632634
test_path.unlink(missing_ok=True)
@@ -754,7 +756,8 @@ def __init__(self, x=2):
754756
testing_time=0.1,
755757
)
756758

757-
assert compare_test_results(test_results, test_results2)
759+
match, _ = compare_test_results(test_results, test_results2)
760+
assert match
758761
finally:
759762
test_path.unlink(missing_ok=True)
760763
sample_code_path.unlink(missing_ok=True)
@@ -902,7 +905,8 @@ def another_helper(self):
902905
testing_time=0.1,
903906
)
904907

905-
assert compare_test_results(test_results, results2)
908+
match, _ = compare_test_results(test_results, results2)
909+
assert match
906910

907911
finally:
908912
test_path.unlink(missing_ok=True)
@@ -1132,7 +1136,8 @@ def target_function(self):
11321136
)
11331137
# Remove instrumentation
11341138
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
1135-
assert not compare_test_results(test_results, mutated_test_results)
1139+
match, _ = compare_test_results(test_results, mutated_test_results)
1140+
assert not match
11361141

11371142
# This fto code stopped using a helper class. it should still pass
11381143
no_helper1_fto_code = """
@@ -1170,7 +1175,8 @@ def target_function(self):
11701175
)
11711176
# Remove instrumentation
11721177
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
1173-
assert compare_test_results(test_results, no_helper1_test_results)
1178+
match, _ = compare_test_results(test_results, no_helper1_test_results)
1179+
assert match
11741180

11751181
finally:
11761182
test_path.unlink(missing_ok=True)

tests/test_comparator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,8 @@ def test_compare_results_fn():
11761176
)
11771177
)
11781178

1179-
assert compare_test_results(original_results, new_results_1)
1179+
match, _ = compare_test_results(original_results, new_results_1)
1180+
assert match
11801181

11811182
new_results_2 = TestResults()
11821183
new_results_2.add(
@@ -1199,7 +1200,8 @@ def test_compare_results_fn():
11991200
)
12001201
)
12011202

1202-
assert not compare_test_results(original_results, new_results_2)
1203+
match, _ = compare_test_results(original_results, new_results_2)
1204+
assert not match
12031205

12041206
new_results_3 = TestResults()
12051207
new_results_3.add(
@@ -1241,7 +1243,8 @@ def test_compare_results_fn():
12411243
)
12421244
)
12431245

1244-
assert compare_test_results(original_results, new_results_3)
1246+
match, _ = compare_test_results(original_results, new_results_3)
1247+
assert match
12451248

12461249
new_results_4 = TestResults()
12471250
new_results_4.add(
@@ -1264,7 +1267,8 @@ def test_compare_results_fn():
12641267
)
12651268
)
12661269

1267-
assert not compare_test_results(original_results, new_results_4)
1270+
match, _ = compare_test_results(original_results, new_results_4)
1271+
assert not match
12681272

12691273
new_results_5_baseline = TestResults()
12701274
new_results_5_baseline.add(
@@ -1308,7 +1312,8 @@ def test_compare_results_fn():
13081312
)
13091313
)
13101314

1311-
assert not compare_test_results(new_results_5_baseline, new_results_5_opt)
1315+
match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt)
1316+
assert not match
13121317

13131318
new_results_6_baseline = TestResults()
13141319
new_results_6_baseline.add(
@@ -1352,9 +1357,11 @@ def test_compare_results_fn():
13521357
)
13531358
)
13541359

1355-
assert not compare_test_results(new_results_6_baseline, new_results_6_opt)
1360+
match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt)
1361+
assert not match
13561362

1357-
assert not compare_test_results(TestResults(), TestResults())
1363+
match, _ = compare_test_results(TestResults(), TestResults())
1364+
assert not match
13581365

13591366

13601367
def test_exceptions():

tests/test_instrument_all_and_run.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ def test_sort():
223223
result: [0, 1, 2, 3, 4, 5]
224224
"""
225225
assert out_str == results2[0].stdout
226-
assert compare_test_results(test_results, results2)
226+
match, _ = compare_test_results(test_results, results2)
227+
assert match
227228
finally:
228229
fto_path.write_text(original_code, "utf-8")
229230
test_path.unlink(missing_ok=True)
@@ -368,7 +369,8 @@ def test_sort():
368369
assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)
369370
out_str = """codeflash stdout : BubbleSorter.sorter() called\n"""
370371
assert test_results[1].stdout == out_str
371-
assert compare_test_results(test_results, test_results)
372+
match, _ = compare_test_results(test_results, test_results)
373+
assert match
372374
assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__"
373375
assert test_results[2].id.test_function_name == "test_sort"
374376
assert test_results[2].did_pass
@@ -396,7 +398,8 @@ def test_sort():
396398
testing_time=0.1,
397399
)
398400

399-
assert compare_test_results(test_results, results2)
401+
match, _ = compare_test_results(test_results, results2)
402+
assert match
400403

401404
# Replace with optimized code that mutated instance attribute
402405
optimized_code = """
@@ -491,7 +494,8 @@ def sorter(self, arr):
491494
)
492495
assert new_test_results[3].runtime > 0
493496
assert new_test_results[3].did_pass
494-
assert not compare_test_results(test_results, new_test_results)
497+
match, _ = compare_test_results(test_results, new_test_results)
498+
assert not match
495499

496500
finally:
497501
fto_path.write_text(original_code, "utf-8")
@@ -630,7 +634,8 @@ def test_sort():
630634
out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called
631635
"""
632636
assert test_results[0].stdout == out_str
633-
assert compare_test_results(test_results, test_results)
637+
match, _ = compare_test_results(test_results, test_results)
638+
assert match
634639

635640
assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod"
636641
assert test_results[1].id.iteration_id == "4_0"
@@ -655,7 +660,8 @@ def test_sort():
655660
testing_time=0.1,
656661
)
657662

658-
assert compare_test_results(test_results, results2)
663+
match, _ = compare_test_results(test_results, results2)
664+
assert match
659665

660666
finally:
661667
fto_path.write_text(original_code, "utf-8")
@@ -794,7 +800,8 @@ def test_sort():
794800
out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called
795801
"""
796802
assert test_results[0].stdout == out_str
797-
assert compare_test_results(test_results, test_results)
803+
match, _ = compare_test_results(test_results, test_results)
804+
assert match
798805

799806
assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod"
800807
assert test_results[1].id.iteration_id == "4_0"
@@ -819,7 +826,8 @@ def test_sort():
819826
testing_time=0.1,
820827
)
821828

822-
assert compare_test_results(test_results, results2)
829+
match, _ = compare_test_results(test_results, results2)
830+
assert match
823831

824832
finally:
825833
fto_path.write_text(original_code, "utf-8")

tests/test_instrumentation_run_results_aiservice.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,10 @@ def sorter(self, arr):
221221
testing_time=0.1,
222222
)
223223
# assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function
224-
assert compare_test_results(
224+
match, _ = compare_test_results(
225225
test_results, test_results_mutated_attr
226226
) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated
227-
227+
assert match
228228
assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
229229
finally:
230230
fto_path.write_text(original_code, "utf-8")
@@ -403,9 +403,10 @@ def sorter(self, arr):
403403
assert test_results_mutated_attr[0].return_value[0] == {"x": 1}
404404
assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO
405405
assert test_results_mutated_attr[0].stdout == ""
406-
assert not compare_test_results(
406+
match,_ = compare_test_results(
407407
test_results, test_results_mutated_attr
408408
) # The test should fail because the instance attribute was mutated
409+
assert not match
409410
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
410411
optimized_code_new_attr = """
411412
import sys
@@ -457,9 +458,10 @@ def sorter(self, arr):
457458
assert test_results_new_attr[0].stdout == ""
458459
# assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input
459460
# assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input
460-
assert compare_test_results(
461+
match,_ = compare_test_results(
461462
test_results, test_results_new_attr
462463
) # The test should pass because the instance attribute was not mutated, only a new one was added
464+
assert match
463465
finally:
464466
fto_path.write_text(original_code, "utf-8")
465467
test_path.unlink(missing_ok=True)

tests/test_pickle_patcher.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container):
427427
testing_time=1.0,
428428
)
429429
assert len(optimized_test_results_unused_socket) == 1
430-
verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
431-
assert verification_result is True
430+
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
431+
assert match
432432

433433
# Remove the previous instrumentation
434434
replay_test_path.write_text(original_replay_test_code)
@@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container):
517517
assert test_results_used_socket.test_results[0].did_pass is False
518518

519519
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
520-
assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
521-
520+
match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket)
521+
assert not match
522522
finally:
523523
# cleanup
524524
output_file.unlink(missing_ok=True)

0 commit comments

Comments
 (0)