From c9f64830c394fb6b6fe831a2d916d2340919ac9d Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Sun, 26 Oct 2025 20:37:44 +0000 Subject: [PATCH] Optimize _compare_hypothesis_tests_semantic The optimized code achieves a **32% speedup** by eliminating redundant data structures and reducing iteration overhead through two key optimizations: **1. Single-pass aggregation instead of list accumulation:** - **Original**: Uses `defaultdict(list)` to collect all `FunctionTestInvocation` objects per test function, then later iterates through these lists to compute failure flags with `any(not ex.did_pass for ex in orig_examples)` - **Optimized**: Uses plain dicts with 2-element lists `[count, had_failure]` to track both example count and failure status in a single pass, eliminating the need to store individual test objects or re-scan them **2. Reduced memory allocation and access patterns:** - **Original**: Creates and stores complete lists of test objects (up to 9,458 objects in large test cases), then performs expensive `any()` operations over these lists - **Optimized**: Uses compact 2-item lists per test function, avoiding object accumulation and expensive linear scans The line profiler shows the key performance gains: - Lines with `any(not ex.did_pass...)` in original (10.1% and 10.2% of total time) are completely eliminated - The `setdefault()` operations replace the more expensive `defaultdict(list).append()` calls - Overall reduction from storing ~9,458 objects to just tracking summary statistics **Best performance gains** occur in test cases with: - **Large numbers of examples per test function** (up to 105% faster for `test_large_scale_all_fail`) - **Many distinct test functions** (up to 75% faster for `test_large_scale_some_failures`) - **Mixed pass/fail scenarios** where the original's `any()` operations were most expensive The optimization maintains identical behavior while dramatically reducing both memory usage and computational complexity from O(examples) to O(1) per test function group. --- codeflash/verification/equivalence.py | 45 +++++++++++++-------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 66ed7e2b4..1bcf4e47e 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,5 +1,4 @@ import sys -from collections import defaultdict from codeflash.cli_cmds.console import logger from codeflash.models.models import FunctionTestInvocation, TestResults, TestType, VerificationType @@ -138,7 +137,6 @@ def _compare_hypothesis_tests_semantic(original_hypothesis: list, candidate_hypo not how many examples Hypothesis generated. """ - # Group by test function (excluding loop index and iteration_id from comparison) def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, str]: """Get unique key for a Hypothesis test function.""" return ( @@ -148,38 +146,39 @@ def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, st test_result.id.function_getting_tested, ) - # Group original results by test function - original_by_func = defaultdict(list) + # Group by test function and simultaneously collect failure flag and example count + orig_by_func = {} for result in original_hypothesis: - original_by_func[get_test_key(result)].append(result) + test_key = get_test_key(result) + group = orig_by_func.setdefault(test_key, [0, False]) # [count, had_failure] + group[0] += 1 + if not result.did_pass: + group[1] = True - # Group candidate results by test function - candidate_by_func = defaultdict(list) + cand_by_func = {} for result in candidate_hypothesis: - candidate_by_func[get_test_key(result)].append(result) + test_key = get_test_key(result) + group = cand_by_func.setdefault(test_key, [0, False]) # [count, had_failure] + group[0] += 1 + if not result.did_pass: + group[1] = True - # Log summary statistics - orig_total_examples = sum(len(examples) for examples in original_by_func.values()) - cand_total_examples = sum(len(examples) for examples in candidate_by_func.values()) + orig_total_examples = sum(group[0] for group in orig_by_func.values()) + cand_total_examples = sum(group[0] for group in cand_by_func.values()) logger.debug( - f"Hypothesis comparison: Original={len(original_by_func)} test functions ({orig_total_examples} examples), " - f"Candidate={len(candidate_by_func)} test functions ({cand_total_examples} examples)" + f"Hypothesis comparison: Original={len(orig_by_func)} test functions ({orig_total_examples} examples), " + f"Candidate={len(cand_by_func)} test functions ({cand_total_examples} examples)" ) - for test_key in original_by_func: - if test_key not in candidate_by_func: + # Compare only for test_keys present in original + for test_key, (orig_count, orig_had_failure) in orig_by_func.items(): + cand_group = cand_by_func.get(test_key) + if cand_group is None: continue # Already handled above - orig_examples = original_by_func[test_key] - cand_examples = candidate_by_func[test_key] + cand_had_failure = cand_group[1] - # Check if any original example failed - orig_had_failure = any(not ex.did_pass for ex in orig_examples) - cand_had_failure = any(not ex.did_pass for ex in cand_examples) - - # If original had failures, candidate must also have failures (or be missing, already handled) - # If original passed, candidate must pass (but can have different example counts) if orig_had_failure != cand_had_failure: logger.debug( f"Hypothesis test function behavior mismatch: {test_key} "