diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 8e50b1d71..abbcb68c1 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -207,23 +207,41 @@ def add_runtime_comments_to_generated_tests( def remove_functions_from_generated_tests( generated_tests: GeneratedTestsList, test_functions_to_remove: list[str] ) -> GeneratedTestsList: + # Pre-compile patterns for all function names to remove + function_patterns = _compile_function_patterns(test_functions_to_remove) new_generated_tests = [] - for generated_test in generated_tests.generated_tests: - for test_function in test_functions_to_remove: - function_pattern = re.compile( - rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)", - re.DOTALL, - ) - - match = function_pattern.search(generated_test.generated_original_test_source) - - if match is None or "@pytest.mark.parametrize" in match.group(0): - continue - - generated_test.generated_original_test_source = function_pattern.sub( - "", generated_test.generated_original_test_source - ) + for generated_test in generated_tests.generated_tests: + source = generated_test.generated_original_test_source + + # Apply all patterns without redundant searches + for pattern in function_patterns: + # Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls + for match in pattern.finditer(source): + # Skip if "@pytest.mark.parametrize" present + # Only the matched function's code is targeted + if "@pytest.mark.parametrize" in match.group(0): + continue + # Remove function from source + # If match, remove the function by substitution in the source + # Replace using start/end indices for efficiency + start, end = match.span() + source = source[:start] + source[end:] + # After removal, break since .finditer() is from left to right, and only one match expected per function in source + break + + generated_test.generated_original_test_source = source new_generated_tests.append(generated_test) return GeneratedTestsList(generated_tests=new_generated_tests) + + +# Pre-compile all function removal regexes upfront for efficiency. +def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]: + return [ + re.compile( + rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)", + re.DOTALL, + ) + for func in test_functions_to_remove + ]