Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,41 @@ def add_runtime_comments_to_generated_tests(
def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
# Pre-compile patterns for all function names to remove
function_patterns = _compile_function_patterns(test_functions_to_remove)
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
re.DOTALL,
)

match = function_pattern.search(generated_test.generated_original_test_source)

if match is None or "@pytest.mark.parametrize" in match.group(0):
continue

generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)

for generated_test in generated_tests.generated_tests:
source = generated_test.generated_original_test_source

# Apply all patterns without redundant searches
for pattern in function_patterns:
# Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls
for match in pattern.finditer(source):
# Skip if "@pytest.mark.parametrize" present
# Only the matched function's code is targeted
if "@pytest.mark.parametrize" in match.group(0):
continue
# Remove function from source
# If match, remove the function by substitution in the source
# Replace using start/end indices for efficiency
start, end = match.span()
source = source[:start] + source[end:]
# After removal, break since .finditer() is from left to right, and only one match expected per function in source
break

generated_test.generated_original_test_source = source
new_generated_tests.append(generated_test)

return GeneratedTestsList(generated_tests=new_generated_tests)


# Pre-compile all function removal regexes upfront for efficiency.
def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
return [
re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
re.DOTALL,
)
for func in test_functions_to_remove
]
Loading