Skip to content

Commit 4866d82

Browse files
committed
cleanup
1 parent 572ac0e commit 4866d82

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

codeflash/verification/equivalence.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
from collections import defaultdict
23

34
from codeflash.cli_cmds.console import logger
45
from codeflash.models.models import TestResults, TestType, VerificationType
@@ -14,14 +15,47 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
1415
original_recursion_limit = sys.getrecursionlimit()
1516
if original_recursion_limit < INCREASED_RECURSION_LIMIT:
1617
sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError
18+
19+
# Separate Hypothesis tests from other test types for semantic comparison
20+
# Hypothesis tests are always compared semantically (by test function, not example count)
21+
original_hypothesis = [
22+
r for r in original_results.test_results if r.test_type == TestType.HYPOTHESIS_TEST and r.loop_index == 1
23+
]
24+
candidate_hypothesis = [
25+
r for r in candidate_results.test_results if r.test_type == TestType.HYPOTHESIS_TEST and r.loop_index == 1
26+
]
27+
28+
# Compare Hypothesis tests semantically if any are present
29+
if original_hypothesis or candidate_hypothesis:
30+
logger.debug(
31+
f"Comparing Hypothesis tests: original={len(original_hypothesis)} examples, "
32+
f"candidate={len(candidate_hypothesis)} examples"
33+
)
34+
hypothesis_equal = _compare_hypothesis_tests_semantic(original_hypothesis, candidate_hypothesis)
35+
if not hypothesis_equal:
36+
logger.info("Hypothesis comparison failed")
37+
sys.setrecursionlimit(original_recursion_limit)
38+
return False
39+
logger.debug("Hypothesis comparison passed")
40+
1741
test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union(
1842
set(candidate_results.get_all_unique_invocation_loop_ids())
1943
)
44+
logger.debug(f"Total test IDs in superset: {len(test_ids_superset)}")
2045
are_equal: bool = True
2146
did_all_timeout: bool = True
2247
for test_id in test_ids_superset:
2348
original_test_result = original_results.get_by_unique_invocation_loop_id(test_id)
2449
cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
50+
51+
# Skip Hypothesis tests - already compared semantically above
52+
if original_test_result and original_test_result.test_type == TestType.HYPOTHESIS_TEST:
53+
did_all_timeout = False # Hypothesis tests are checked separately, not timed out
54+
continue
55+
if cdd_test_result and cdd_test_result.test_type == TestType.HYPOTHESIS_TEST:
56+
did_all_timeout = False
57+
continue
58+
2559
if cdd_test_result is not None and original_test_result is None:
2660
continue
2761
# If helper function instance_state verification is not present, that's ok. continue
@@ -33,6 +67,11 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
3367
continue
3468
if original_test_result is None or cdd_test_result is None:
3569
are_equal = False
70+
logger.debug(
71+
f"Test result mismatch: test_id={test_id}, "
72+
f"original_present={original_test_result is not None}, "
73+
f"candidate_present={cdd_test_result is not None}"
74+
)
3675
break
3776
did_all_timeout = did_all_timeout and original_test_result.timed_out
3877
if original_test_result.timed_out:
@@ -80,5 +119,89 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
80119
break
81120
sys.setrecursionlimit(original_recursion_limit)
82121
if did_all_timeout:
122+
logger.debug("Comparison failed: all tests timed out")
83123
return False
124+
logger.debug(f"Final comparison result: are_equal={are_equal}")
84125
return are_equal
126+
127+
128+
def _compare_hypothesis_tests_semantic(original_hypothesis: list, candidate_hypothesis: list) -> bool:
    """Compare Hypothesis tests by test function, not by example count.

    Hypothesis can generate different numbers of examples between runs due to:
    - Timing differences
    - Early stopping
    - Shrinking behavior
    - Performance differences

    What matters is whether the test functions themselves pass or fail,
    not how many examples Hypothesis generated.

    Args:
        original_hypothesis: Hypothesis test results recorded for the original code.
            Each item must expose ``id.test_module_path``, ``id.test_class_name``,
            ``id.test_function_name``, ``id.function_getting_tested`` and ``did_pass``.
        candidate_hypothesis: Hypothesis test results recorded for the candidate code,
            with the same attributes.

    Returns:
        ``False`` if a test function present in the original results has no results in
        the candidate, or if the "at least one example failed" status of any original
        test function differs from the candidate's. ``True`` otherwise. Test functions
        that only appear in the candidate results are ignored.

    """

    def get_test_key(test_result) -> tuple:
        """Get unique key for a Hypothesis test function (loop index and iteration id excluded)."""
        return (
            test_result.id.test_module_path,
            test_result.id.test_class_name,
            test_result.id.test_function_name,
            test_result.id.function_getting_tested,
        )

    # Group every recorded example under the test function that produced it.
    original_by_func = defaultdict(list)
    for result in original_hypothesis:
        original_by_func[get_test_key(result)].append(result)

    candidate_by_func = defaultdict(list)
    for result in candidate_hypothesis:
        candidate_by_func[get_test_key(result)].append(result)

    # Log summary statistics
    orig_total_examples = sum(len(examples) for examples in original_by_func.values())
    cand_total_examples = sum(len(examples) for examples in candidate_by_func.values())

    logger.debug(
        f"Hypothesis comparison: Original={len(original_by_func)} test functions ({orig_total_examples} examples), "
        f"Candidate={len(candidate_by_func)} test functions ({cand_total_examples} examples)"
    )

    # Every test function seen for the original code must also have run for the candidate.
    missing_funcs = set(original_by_func) - set(candidate_by_func)
    if missing_funcs:
        logger.warning(
            f"Hypothesis test functions missing in candidate: {len(missing_funcs)} functions. "
            f"First missing: {next(iter(missing_funcs))}"
        )
        return False

    # From here on every original key is guaranteed to exist in candidate_by_func,
    # so indexing it below never creates a new (empty) defaultdict entry.
    for test_key, orig_examples in original_by_func.items():
        cand_examples = candidate_by_func[test_key]

        # A test function counts as failed if at least one of its examples failed.
        orig_had_failure = any(not ex.did_pass for ex in orig_examples)
        cand_had_failure = any(not ex.did_pass for ex in cand_examples)

        # The pass/fail outcome per test function must match; example counts may differ.
        if orig_had_failure != cand_had_failure:
            logger.debug(
                f"Hypothesis test function behavior mismatch: {test_key} "
                f"(original_failed={orig_had_failure}, candidate_failed={cand_had_failure})"
            )
            return False

        # A large gap in example counts is only reported; it does not affect the result.
        if abs(len(orig_examples) - len(cand_examples)) > 10:
            logger.info(
                f"Hypothesis test '{test_key[2]}': example counts differ "
                f"(original={len(orig_examples)}, candidate={len(cand_examples)}). "
                "This is expected when code performance changes."
            )

    return True

0 commit comments

Comments
 (0)