Skip to content

Commit c9b4b29

Browse files
authored
Merge pull request codeflash-ai#537 from codeflash-ai/runtime-fixes-jul10
Faster and more robust Runtime Comments annotations for generated tests
2 parents 6add2dc + 47fa9c4 commit c9b4b29

File tree

4 files changed

+567
-349
lines changed

4 files changed

+567
-349
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 152 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import os
55
import re
66
from pathlib import Path
7-
from textwrap import dedent
87
from typing import TYPE_CHECKING
98

109
import libcst as cst
10+
from libcst import MetadataWrapper
11+
from libcst.metadata import PositionProvider
1112

1213
from codeflash.cli_cmds.console import logger
1314
from codeflash.code_utils.time_utils import format_perf, format_time
@@ -16,230 +17,167 @@
1617

1718
if TYPE_CHECKING:
1819
from codeflash.models.models import InvocationId
19-
from codeflash.verification.verification_utils import TestConfig
2020

2121

22-
def remove_functions_from_generated_tests(
23-
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
24-
) -> GeneratedTestsList:
25-
new_generated_tests = []
26-
for generated_test in generated_tests.generated_tests:
27-
for test_function in test_functions_to_remove:
28-
function_pattern = re.compile(
29-
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
30-
re.DOTALL,
31-
)
32-
33-
match = function_pattern.search(generated_test.generated_original_test_source)
34-
35-
if match is None or "@pytest.mark.parametrize" in match.group(0):
36-
continue
37-
38-
generated_test.generated_original_test_source = function_pattern.sub(
39-
"", generated_test.generated_original_test_source
40-
)
22+
class CommentMapper(ast.NodeVisitor):
23+
def __init__(
24+
self, test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
25+
) -> None:
26+
self.results: dict[int, str] = {}
27+
self.test: GeneratedTests = test
28+
self.original_runtimes = original_runtimes
29+
self.optimized_runtimes = optimized_runtimes
30+
self.abs_path = test.behavior_file_path.with_suffix("")
31+
self.context_stack: list[str] = []
4132

42-
new_generated_tests.append(generated_test)
33+
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
34+
self.context_stack.append(node.name)
35+
for inner_node in ast.walk(node):
36+
if isinstance(inner_node, ast.FunctionDef):
37+
self.visit_FunctionDef(inner_node)
38+
self.context_stack.pop()
39+
return node
40+
41+
def get_comment(self, match_key: str) -> str:
42+
# calculate speedup and output comment
43+
original_time = self.original_runtimes[match_key]
44+
optimized_time = self.optimized_runtimes[match_key]
45+
perf_gain = format_perf(
46+
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
47+
)
48+
status = "slower" if optimized_time > original_time else "faster"
49+
# Create the runtime comment
50+
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
51+
52+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
53+
self.context_stack.append(node.name)
54+
i = len(node.body) - 1
55+
test_qualified_name = ".".join(self.context_stack)
56+
key = test_qualified_name + "#" + str(self.abs_path)
57+
while i >= 0:
58+
line_node = node.body[i]
59+
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
60+
j = len(line_node.body) - 1
61+
while j >= 0:
62+
compound_line_node: ast.stmt = line_node.body[j]
63+
internal_node: ast.AST
64+
for internal_node in ast.walk(compound_line_node):
65+
if isinstance(internal_node, (ast.stmt, ast.Assign)):
66+
inv_id = str(i) + "_" + str(j)
67+
match_key = key + "#" + inv_id
68+
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
69+
self.results[internal_node.lineno] = self.get_comment(match_key)
70+
j -= 1
71+
else:
72+
inv_id = str(i)
73+
match_key = key + "#" + inv_id
74+
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
75+
self.results[line_node.lineno] = self.get_comment(match_key)
76+
i -= 1
77+
self.context_stack.pop()
78+
return node
4379

44-
return GeneratedTestsList(generated_tests=new_generated_tests)
4580

81+
def get_fn_call_linenos(
82+
test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
83+
) -> dict[int, str]:
84+
line_comment_ast_mapper = CommentMapper(test, original_runtimes, optimized_runtimes)
85+
source_code = test.generated_original_test_source
86+
tree = ast.parse(source_code)
87+
line_comment_ast_mapper.visit(tree)
88+
return line_comment_ast_mapper.results
4689

47-
class CfoVisitor(ast.NodeVisitor):
48-
"""AST visitor that finds all assignments to a variable named 'codeflash_output'.
4990

50-
and reports their location relative to the function they're in.
51-
"""
91+
class CommentAdder(cst.CSTTransformer):
92+
"""Transformer that adds comments to specified lines."""
5293

53-
def __init__(self, function_name: str, source_code: str) -> None:
54-
self.source_lines = source_code.splitlines()
55-
self.name = function_name
56-
self.results: list[int] = [] # map actual line number to line number in ast
94+
# Declare metadata dependencies
95+
METADATA_DEPENDENCIES = (PositionProvider,)
5796

58-
def visit_Call(self, node): # type: ignore[no-untyped-def] # noqa: ANN201, ANN001
59-
"""Detect fn calls."""
60-
func_name = self._get_called_func_name(node.func) # type: ignore[no-untyped-call]
61-
if func_name == self.name:
62-
self.results.append(node.lineno - 1)
63-
self.generic_visit(node)
97+
def __init__(self, line_to_comments: dict[int, str]) -> None:
98+
"""Initialize the transformer with target line numbers.
6499
65-
def _get_called_func_name(self, node): # type: ignore[no-untyped-def] # noqa: ANN001, ANN202
66-
"""Return name of called fn."""
67-
if isinstance(node, ast.Name):
68-
return node.id
69-
if isinstance(node, ast.Attribute):
70-
return node.attr
71-
return None
100+
Args:
101+
line_to_comments: Mapping of line numbers (1-indexed) to comments
72102
103+
"""
104+
self.line_to_comments = line_to_comments
105+
super().__init__()
73106

74-
def find_codeflash_output_assignments(function_name: str, source_code: str) -> list[int]:
75-
tree = ast.parse(source_code)
76-
visitor = CfoVisitor(function_name, source_code)
77-
visitor.visit(tree)
78-
return visitor.results
107+
def leave_SimpleStatementLine(
108+
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
109+
) -> cst.SimpleStatementLine:
110+
"""Add comment to simple statement lines."""
111+
pos = self.get_metadata(PositionProvider, original_node)
79112

113+
if pos and pos.start.line in self.line_to_comments:
114+
# Create a comment with trailing whitespace
115+
comment = cst.TrailingWhitespace(
116+
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
117+
)
80118

81-
class Finder(cst.CSTVisitor):
82-
def __init__(self, name: str) -> None:
83-
super().__init__()
84-
self.found = False
85-
self.name = name
119+
# Update the trailing whitespace of the line itself
120+
return updated_node.with_changes(trailing_whitespace=comment)
86121

87-
def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa : ANN001
88-
func_expr = call_node.func
89-
if isinstance(func_expr, cst.Name):
90-
if func_expr.value == self.name:
91-
self.found = True
92-
elif isinstance(func_expr, cst.Attribute): # noqa : SIM102
93-
if func_expr.attr.value == self.name:
94-
self.found = True
122+
return updated_node
95123

124+
def leave_SimpleStatementSuite(
125+
self, original_node: cst.SimpleStatementSuite, updated_node: cst.SimpleStatementSuite
126+
) -> cst.SimpleStatementSuite:
127+
"""Add comment to simple statement suites (e.g., after if/for/while)."""
128+
pos = self.get_metadata(PositionProvider, original_node)
96129

97-
# TODO: reduce for loops to one
98-
class RuntimeCommentTransformer(cst.CSTTransformer):
99-
def __init__(
100-
self,
101-
qualified_name: str,
102-
module: cst.Module,
103-
test: GeneratedTests,
104-
tests_root: Path,
105-
original_runtimes: dict[InvocationId, list[int]],
106-
optimized_runtimes: dict[InvocationId, list[int]],
107-
) -> None:
108-
super().__init__()
109-
self.test = test
110-
self.context_stack: list[str] = []
111-
self.tests_root = tests_root
112-
self.module = module
113-
self.cfo_locs: list[int] = []
114-
self.cfo_idx_loc_to_look_at: int = -1
115-
self.name = qualified_name.split(".")[-1]
116-
self.original_runtimes = original_runtimes
117-
self.optimized_runtimes = optimized_runtimes
118-
119-
def visit_ClassDef(self, node: cst.ClassDef) -> None:
120-
# Track when we enter a class
121-
self.context_stack.append(node.name.value)
130+
if pos and pos.start.line in self.line_to_comments:
131+
# Create a comment with trailing whitespace
132+
comment = cst.TrailingWhitespace(
133+
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
134+
)
122135

123-
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
124-
# Pop the context when we leave a class
125-
self.context_stack.pop()
126-
return updated_node
136+
# Update the trailing whitespace of the suite
137+
return updated_node.with_changes(trailing_whitespace=comment)
127138

128-
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
129-
# convert function body to ast normalized string and find occurrences of codeflash_output
130-
body_code = dedent(self.module.code_for_node(node.body))
131-
normalized_body_code = ast.unparse(ast.parse(body_code))
132-
self.cfo_locs = sorted(
133-
find_codeflash_output_assignments(self.name, normalized_body_code)
134-
) # sorted in order we will encounter them
135-
self.cfo_idx_loc_to_look_at = -1
136-
self.context_stack.append(node.name.value)
137-
138-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
139-
# Pop the context when we leave a function
140-
self.context_stack.pop()
141139
return updated_node
142140

143-
def leave_SimpleStatementLine(
144-
self,
145-
original_node: cst.SimpleStatementLine, # noqa: ARG002
146-
updated_node: cst.SimpleStatementLine,
147-
) -> cst.SimpleStatementLine:
148-
# Check if this statement line contains a call to self.name
149-
if self._contains_myfunc_call(updated_node): # type: ignore[no-untyped-call]
150-
# Find matching test cases by looking for this test function name in the test results
151-
self.cfo_idx_loc_to_look_at += 1
152-
matching_original_times = []
153-
matching_optimized_times = []
154-
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
155-
for invocation_id, runtimes in self.original_runtimes.items():
156-
# get position here and match in if condition
157-
qualified_name = (
158-
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
159-
if invocation_id.test_class_name
160-
else invocation_id.test_function_name
161-
)
162-
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
163-
if (
164-
qualified_name == ".".join(self.context_stack)
165-
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
166-
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
167-
):
168-
matching_original_times.extend(runtimes)
169-
170-
for invocation_id, runtimes in self.optimized_runtimes.items():
171-
# get position here and match in if condition
172-
qualified_name = (
173-
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
174-
if invocation_id.test_class_name
175-
else invocation_id.test_function_name
176-
)
177-
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
178-
if (
179-
qualified_name == ".".join(self.context_stack)
180-
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
181-
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
182-
):
183-
matching_optimized_times.extend(runtimes)
184-
185-
if matching_original_times and matching_optimized_times:
186-
original_time = min(matching_original_times)
187-
optimized_time = min(matching_optimized_times)
188-
if original_time != 0 and optimized_time != 0:
189-
perf_gain = format_perf(
190-
abs(
191-
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
192-
* 100
193-
)
194-
)
195-
status = "slower" if optimized_time > original_time else "faster"
196-
# Create the runtime comment
197-
comment_text = (
198-
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
199-
)
200-
return updated_node.with_changes(
201-
trailing_whitespace=cst.TrailingWhitespace(
202-
whitespace=cst.SimpleWhitespace(" "),
203-
comment=cst.Comment(comment_text),
204-
newline=updated_node.trailing_whitespace.newline,
205-
)
206-
)
207-
return updated_node
208141

209-
def _contains_myfunc_call(self, node): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
210-
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
211-
finder = Finder(self.name)
212-
node.visit(finder)
213-
return finder.found
142+
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
143+
unique_inv_ids: dict[str, int] = {}
144+
for inv_id, runtimes in inv_id_runtimes.items():
145+
test_qualified_name = (
146+
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
147+
if inv_id.test_class_name
148+
else inv_id.test_function_name
149+
)
150+
abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
151+
if "__unit_test_" not in abs_path:
152+
continue
153+
key = test_qualified_name + "#" + abs_path # type: ignore[operator]
154+
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
155+
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
156+
match_key = key + "#" + cur_invid
157+
if match_key not in unique_inv_ids:
158+
unique_inv_ids[match_key] = 0
159+
unique_inv_ids[match_key] += min(runtimes)
160+
return unique_inv_ids
214161

215162

216163
def add_runtime_comments_to_generated_tests(
217-
qualified_name: str,
218-
test_cfg: TestConfig,
219164
generated_tests: GeneratedTestsList,
220165
original_runtimes: dict[InvocationId, list[int]],
221166
optimized_runtimes: dict[InvocationId, list[int]],
222167
) -> GeneratedTestsList:
223168
"""Add runtime performance comments to function calls in generated tests."""
224-
tests_root = test_cfg.tests_root
225-
169+
original_runtimes_dict = unique_inv_id(original_runtimes)
170+
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
226171
# Process each generated test
227172
modified_tests = []
228173
for test in generated_tests.generated_tests:
229174
try:
230-
# Parse the test source code
231175
tree = cst.parse_module(test.generated_original_test_source)
232-
# Transform the tree to add runtime comments
233-
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path
234-
transformer = RuntimeCommentTransformer(
235-
qualified_name, tree, test, tests_root, original_runtimes, optimized_runtimes
236-
)
237-
modified_tree = tree.visit(transformer)
238-
239-
# Convert back to source code
176+
wrapper = MetadataWrapper(tree)
177+
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
178+
comment_adder = CommentAdder(line_to_comments)
179+
modified_tree = wrapper.visit(comment_adder)
240180
modified_source = modified_tree.code
241-
242-
# Create a new GeneratedTests object with the modified source
243181
modified_test = GeneratedTests(
244182
generated_original_test_source=modified_source,
245183
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
@@ -254,3 +192,28 @@ def add_runtime_comments_to_generated_tests(
254192
modified_tests.append(test)
255193

256194
return GeneratedTestsList(generated_tests=modified_tests)
195+
196+
197+
def remove_functions_from_generated_tests(
198+
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
199+
) -> GeneratedTestsList:
200+
new_generated_tests = []
201+
for generated_test in generated_tests.generated_tests:
202+
for test_function in test_functions_to_remove:
203+
function_pattern = re.compile(
204+
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
205+
re.DOTALL,
206+
)
207+
208+
match = function_pattern.search(generated_test.generated_original_test_source)
209+
210+
if match is None or "@pytest.mark.parametrize" in match.group(0):
211+
continue
212+
213+
generated_test.generated_original_test_source = function_pattern.sub(
214+
"", generated_test.generated_original_test_source
215+
)
216+
217+
new_generated_tests.append(generated_test)
218+
219+
return GeneratedTestsList(generated_tests=new_generated_tests)

0 commit comments

Comments
 (0)