diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 66ed7e2b4..1bcf4e47e 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,5 +1,4 @@ import sys -from collections import defaultdict from codeflash.cli_cmds.console import logger from codeflash.models.models import FunctionTestInvocation, TestResults, TestType, VerificationType @@ -138,7 +137,6 @@ def _compare_hypothesis_tests_semantic(original_hypothesis: list, candidate_hypo not how many examples Hypothesis generated. """ - # Group by test function (excluding loop index and iteration_id from comparison) def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, str]: """Get unique key for a Hypothesis test function.""" return ( @@ -148,38 +146,39 @@ def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, st test_result.id.function_getting_tested, ) - # Group original results by test function - original_by_func = defaultdict(list) + # Group by test function and simultaneously collect failure flag and example count + orig_by_func = {} for result in original_hypothesis: - original_by_func[get_test_key(result)].append(result) + test_key = get_test_key(result) + group = orig_by_func.setdefault(test_key, [0, False]) # [count, had_failure] + group[0] += 1 + if not result.did_pass: + group[1] = True - # Group candidate results by test function - candidate_by_func = defaultdict(list) + cand_by_func = {} for result in candidate_hypothesis: - candidate_by_func[get_test_key(result)].append(result) + test_key = get_test_key(result) + group = cand_by_func.setdefault(test_key, [0, False]) # [count, had_failure] + group[0] += 1 + if not result.did_pass: + group[1] = True - # Log summary statistics - orig_total_examples = sum(len(examples) for examples in original_by_func.values()) - cand_total_examples = sum(len(examples) for examples in candidate_by_func.values()) + orig_total_examples = sum(group[0] for group in orig_by_func.values()) + cand_total_examples = sum(group[0] for group in cand_by_func.values()) logger.debug( - f"Hypothesis comparison: Original={len(original_by_func)} test functions ({orig_total_examples} examples), " - f"Candidate={len(candidate_by_func)} test functions ({cand_total_examples} examples)" + f"Hypothesis comparison: Original={len(orig_by_func)} test functions ({orig_total_examples} examples), " + f"Candidate={len(cand_by_func)} test functions ({cand_total_examples} examples)" ) - for test_key in original_by_func: - if test_key not in candidate_by_func: + # Compare only for test_keys present in original + for test_key, (orig_count, orig_had_failure) in orig_by_func.items(): + cand_group = cand_by_func.get(test_key) + if cand_group is None: continue # Already handled above - orig_examples = original_by_func[test_key] - cand_examples = candidate_by_func[test_key] + cand_had_failure = cand_group[1] - # Check if any original example failed - orig_had_failure = any(not ex.did_pass for ex in orig_examples) - cand_had_failure = any(not ex.did_pass for ex in cand_examples) - - # If original had failures, candidate must also have failures (or be missing, already handled) - # If original passed, candidate must pass (but can have different example counts) if orig_had_failure != cand_had_failure: logger.debug( f"Hypothesis test function behavior mismatch: {test_key} "