Skip to content

Commit a6e8cdd

Browse files
authored
Merge pull request #858 from codeflash-ai/codeflash/optimize-pr857-2025-10-26T20.37.41
⚡️ Speed up function `_compare_hypothesis_tests_semantic` by 32% in PR #857 (`feat/hypothesis-tests`)
2 parents 8fb7c1e + c9f6483 commit a6e8cdd

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

codeflash/verification/equivalence.py

Lines changed: 22 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import sys
2-
from collections import defaultdict
32

43
from codeflash.cli_cmds.console import logger
54
from codeflash.models.models import FunctionTestInvocation, TestResults, TestType, VerificationType
@@ -138,7 +137,6 @@ def _compare_hypothesis_tests_semantic(original_hypothesis: list, candidate_hypo
138137
not how many examples Hypothesis generated.
139138
"""
140139

141-
# Group by test function (excluding loop index and iteration_id from comparison)
142140
def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, str]:
143141
"""Get unique key for a Hypothesis test function."""
144142
return (
@@ -148,38 +146,39 @@ def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, st
148146
test_result.id.function_getting_tested,
149147
)
150148

151-
# Group original results by test function
152-
original_by_func = defaultdict(list)
149+
# Group by test function and simultaneously collect failure flag and example count
150+
orig_by_func = {}
153151
for result in original_hypothesis:
154-
original_by_func[get_test_key(result)].append(result)
152+
test_key = get_test_key(result)
153+
group = orig_by_func.setdefault(test_key, [0, False]) # [count, had_failure]
154+
group[0] += 1
155+
if not result.did_pass:
156+
group[1] = True
155157

156-
# Group candidate results by test function
157-
candidate_by_func = defaultdict(list)
158+
cand_by_func = {}
158159
for result in candidate_hypothesis:
159-
candidate_by_func[get_test_key(result)].append(result)
160+
test_key = get_test_key(result)
161+
group = cand_by_func.setdefault(test_key, [0, False]) # [count, had_failure]
162+
group[0] += 1
163+
if not result.did_pass:
164+
group[1] = True
160165

161-
# Log summary statistics
162-
orig_total_examples = sum(len(examples) for examples in original_by_func.values())
163-
cand_total_examples = sum(len(examples) for examples in candidate_by_func.values())
166+
orig_total_examples = sum(group[0] for group in orig_by_func.values())
167+
cand_total_examples = sum(group[0] for group in cand_by_func.values())
164168

165169
logger.debug(
166-
f"Hypothesis comparison: Original={len(original_by_func)} test functions ({orig_total_examples} examples), "
167-
f"Candidate={len(candidate_by_func)} test functions ({cand_total_examples} examples)"
170+
f"Hypothesis comparison: Original={len(orig_by_func)} test functions ({orig_total_examples} examples), "
171+
f"Candidate={len(cand_by_func)} test functions ({cand_total_examples} examples)"
168172
)
169173

170-
for test_key in original_by_func:
171-
if test_key not in candidate_by_func:
174+
# Compare only for test_keys present in original
175+
for test_key, (orig_count, orig_had_failure) in orig_by_func.items():
176+
cand_group = cand_by_func.get(test_key)
177+
if cand_group is None:
172178
continue # Already handled above
173179

174-
orig_examples = original_by_func[test_key]
175-
cand_examples = candidate_by_func[test_key]
180+
cand_had_failure = cand_group[1]
176181

177-
# Check if any original example failed
178-
orig_had_failure = any(not ex.did_pass for ex in orig_examples)
179-
cand_had_failure = any(not ex.did_pass for ex in cand_examples)
180-
181-
# If original had failures, candidate must also have failures (or be missing, already handled)
182-
# If original passed, candidate must pass (but can have different example counts)
183182
if orig_had_failure != cand_had_failure:
184183
logger.debug(
185184
f"Hypothesis test function behavior mismatch: {test_key} "

0 commit comments

Comments
 (0)