|
| 1 | +import re |
| 2 | + |
| 3 | +import libcst as cst |
| 4 | + |
| 5 | +from codeflash.cli_cmds.console import logger |
| 6 | +from codeflash.code_utils.time_utils import format_time |
| 7 | +from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults |
| 8 | + |
| 9 | + |
def remove_functions_from_generated_tests(
    generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
    """Strip the named test functions from each generated test's source.

    Each function in ``test_functions_to_remove`` is deleted (by regex) from every
    generated test's ``generated_original_test_source``. A function is left in place
    when it is decorated with ``@pytest.mark.parametrize`` — parametrized tests are
    deliberately kept.

    Args:
        generated_tests: The list of generated tests to filter.
        test_functions_to_remove: Names of test functions to delete from the sources.

    Returns:
        A new ``GeneratedTestsList`` wrapping the (mutated in place) generated tests.

    """
    # Compile each removal pattern once: the pattern depends only on the function
    # name, not on the individual generated test, so hoist it out of the per-test
    # loop instead of recompiling it len(tests) * len(functions) times.
    function_patterns = [
        re.compile(
            rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
            re.DOTALL,
        )
        for test_function in test_functions_to_remove
    ]

    new_generated_tests = []
    for generated_test in generated_tests.generated_tests:
        for function_pattern in function_patterns:
            match = function_pattern.search(generated_test.generated_original_test_source)

            # Skip when the function is absent, or when the first match carries a
            # @pytest.mark.parametrize decorator (parametrized tests are kept).
            if match is None or "@pytest.mark.parametrize" in match.group(0):
                continue

            generated_test.generated_original_test_source = function_pattern.sub(
                "", generated_test.generated_original_test_source
            )

        new_generated_tests.append(generated_test)

    return GeneratedTestsList(generated_tests=new_generated_tests)
| 33 | + |
| 34 | + |
def add_runtime_comments_to_generated_tests(
    generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
) -> GeneratedTestsList:
    """Add runtime performance comments to function calls in generated tests.

    Inside every ``test_*`` function, any statement line assigning to the
    ``codeflash_output`` variable gets a trailing comment of the form
    ``# <original time> -> <optimized time>``, where each time is the minimum
    usable runtime recorded for that test function. Tests whose source fails
    to parse are returned unchanged.

    Args:
        generated_tests: The generated tests to annotate.
        original_test_results: Runtimes measured against the original code.
        optimized_test_results: Runtimes measured against the optimized code.

    Returns:
        A new ``GeneratedTestsList`` with annotated test sources.

    """
    # Dictionaries for fast lookup of runtime data, keyed by invocation id.
    original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
    optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()

    class RuntimeCommentTransformer(cst.CSTTransformer):
        """Appends runtime comments to ``codeflash_output = ...`` lines in tests."""

        def __init__(self) -> None:
            # Stack of enclosing function names. Tracking the full nesting —
            # rather than a single boolean flag — keeps the state correct when
            # a test function contains a nested helper ``def``: the flag
            # version lost the test name for every statement after the helper.
            self._function_stack: list[str] = []

        def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
            self._function_stack.append(node.name.value)

        def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
            self._function_stack.pop()
            return updated_node

        def _current_test_name(self) -> str | None:
            # Only statements whose innermost enclosing function is a test
            # (name starts with "test_") are eligible for annotation.
            if self._function_stack and self._function_stack[-1].startswith("test_"):
                return self._function_stack[-1]
            return None

        def leave_SimpleStatementLine(
            self,
            original_node: cst.SimpleStatementLine,  # noqa: ARG002
            updated_node: cst.SimpleStatementLine,
        ) -> cst.SimpleStatementLine:
            test_name = self._current_test_name()
            if test_name is None:
                return updated_node

            # Look for an assignment to codeflash_output anywhere on the line
            # (handles multiple statements separated by semicolons).
            codeflash_assignment_found = any(
                isinstance(stmt, cst.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0].target, cst.Name)
                and stmt.targets[0].target.value == "codeflash_output"
                for stmt in updated_node.body
            )
            if not codeflash_assignment_found:
                return updated_node

            # Collect every recorded runtime for this test function on each side.
            matching_original_times = [
                runtime
                for invocation_id, runtimes in original_runtime_by_test.items()
                if invocation_id.test_function_name == test_name
                for runtime in runtimes
            ]
            matching_optimized_times = [
                runtime
                for invocation_id, runtimes in optimized_runtime_by_test.items()
                if invocation_id.test_function_name == test_name
                for runtime in runtimes
            ]
            if not (matching_original_times and matching_optimized_times):
                return updated_node

            # Best (minimum) observed runtime on each side.
            original_time = min(matching_original_times)
            optimized_time = min(matching_optimized_times)
            comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"

            # Attach the comment via the line's trailing whitespace, preserving
            # the original newline token.
            new_trailing_whitespace = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(" "),
                comment=cst.Comment(comment_text),
                newline=updated_node.trailing_whitespace.newline,
            )
            return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

    # Process each generated test; a test that cannot be parsed is kept as-is.
    modified_tests = []
    for test in generated_tests.generated_tests:
        try:
            tree = cst.parse_module(test.generated_original_test_source)
            modified_tree = tree.visit(RuntimeCommentTransformer())
            modified_tests.append(
                GeneratedTests(
                    generated_original_test_source=modified_tree.code,
                    instrumented_behavior_test_source=test.instrumented_behavior_test_source,
                    instrumented_perf_test_source=test.instrumented_perf_test_source,
                    behavior_file_path=test.behavior_file_path,
                    perf_file_path=test.perf_file_path,
                )
            )
        except Exception as e:
            # Best-effort: annotation must never break test generation, so log
            # the failure and keep the original test unchanged.
            logger.debug(f"Failed to add runtime comments to test: {e}")
            modified_tests.append(test)

    return GeneratedTestsList(generated_tests=modified_tests)
0 commit comments