Skip to content

Commit 9316ee7

Browse files
committed
wip
1 parent 67a9e79 commit 9316ee7

File tree

2 files changed

+55
-102
lines changed

2 files changed

+55
-102
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 52 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from pathlib import Path
77
from textwrap import dedent
8-
from typing import TYPE_CHECKING, Union
8+
from typing import TYPE_CHECKING, Union, Optional
99

1010
import libcst as cst
1111

@@ -50,89 +50,45 @@ class CfoVisitor(ast.NodeVisitor):
5050
and reports their location relative to the function they're in.
5151
"""
5252

53-
def __init__(self, source_code: str) -> None:
53+
def __init__(self, qualifed_name: str, source_code: str) -> None:
5454
self.source_lines = source_code.splitlines()
55+
self.name = qualifed_name.split('.')[-1]
5556
self.results: list[int] = [] # map actual line number to line number in ast
5657

57-
def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool: # type: ignore[type-arg]
58-
"""Check if the assignment target is the variable 'codeflash_output'."""
59-
if isinstance(target, ast.Name):
60-
return target.id == "codeflash_output"
61-
if isinstance(target, (ast.Tuple, ast.List)):
62-
# Handle tuple/list unpacking: a, codeflash_output, b = values
63-
return any(self._is_codeflash_output_target(elt) for elt in target.elts)
64-
if isinstance(target, (ast.Subscript, ast.Attribute)):
65-
# Not a simple variable assignment
66-
return False
67-
return False
68-
69-
def _record_assignment(self, node: ast.AST) -> None:
70-
"""Record an assignment to codeflash_output."""
71-
relative_line = node.lineno - 1 # type: ignore[attr-defined]
72-
self.results.append(relative_line)
73-
74-
def visit_Assign(self, node: ast.Assign) -> None:
75-
"""Visit assignment statements: codeflash_output = value."""
76-
for target in node.targets:
77-
if self._is_codeflash_output_target(target):
78-
self._record_assignment(node)
79-
break
58+
def visit_Call(self, node):
59+
"""
60+
Detect calls to:
61+
- myfunc(...)
62+
- obj.myfunc(...)
63+
"""
64+
func_name = self._get_called_func_name(node.func)
65+
if func_name == self.name:
66+
self.results.append(node.lineno)
8067
self.generic_visit(node)
8168

82-
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
83-
"""Visit annotated assignments: codeflash_output: int = value."""
84-
if self._is_codeflash_output_target(node.target):
85-
self._record_assignment(node)
86-
self.generic_visit(node)
87-
88-
def visit_AugAssign(self, node: ast.AugAssign) -> None:
89-
"""Visit augmented assignments: codeflash_output += value."""
90-
if self._is_codeflash_output_target(node.target):
91-
self._record_assignment(node)
92-
self.generic_visit(node)
93-
94-
def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
95-
"""Visit walrus operator: (codeflash_output := value)."""
96-
if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
97-
self._record_assignment(node)
98-
self.generic_visit(node)
99-
100-
def visit_For(self, node: ast.For) -> None:
101-
"""Visit for loops: for codeflash_output in iterable."""
102-
if self._is_codeflash_output_target(node.target):
103-
self._record_assignment(node)
104-
self.generic_visit(node)
105-
106-
def visit_comprehension(self, node: ast.comprehension) -> None:
107-
"""Visit comprehensions: [x for codeflash_output in iterable]."""
108-
if self._is_codeflash_output_target(node.target):
109-
# Comprehensions don't have line numbers, so we skip recording
110-
pass
111-
self.generic_visit(node)
112-
113-
def visit_With(self, node: ast.With) -> None:
114-
"""Visit with statements: with expr as codeflash_output."""
115-
for item in node.items:
116-
if item.optional_vars and self._is_codeflash_output_target(item.optional_vars):
117-
self._record_assignment(node)
118-
break
119-
self.generic_visit(node)
120-
121-
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
122-
"""Visit except handlers: except Exception as codeflash_output."""
123-
if node.name == "codeflash_output":
124-
self._record_assignment(node)
125-
self.generic_visit(node)
126-
127-
128-
def find_codeflash_output_assignments(source_code: str) -> list[int]:
69+
def _get_called_func_name(self, node):
70+
"""
71+
Given a node like:
72+
- Name(id='myfunc')
73+
- Attribute(value=..., attr='myfunc')
74+
Return the final function name if possible.
75+
"""
76+
if isinstance(node, ast.Name):
77+
return node.id
78+
elif isinstance(node, ast.Attribute):
79+
return node.attr
80+
return None
81+
82+
83+
def find_codeflash_output_assignments(qualifed_name: str, source_code: str) -> list[int]:
12984
tree = ast.parse(source_code)
130-
visitor = CfoVisitor(source_code)
85+
visitor = CfoVisitor(qualifed_name, source_code)
13186
visitor.visit(tree)
13287
return visitor.results
13388

13489

13590
def add_runtime_comments_to_generated_tests(
91+
qualifed_name: str,
13692
test_cfg: TestConfig,
13793
generated_tests: GeneratedTestsList,
13894
original_runtimes: dict[InvocationId, list[int]],
@@ -169,7 +125,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
169125
body_code = dedent(self.module.code_for_node(node.body))
170126
normalized_body_code = ast.unparse(ast.parse(body_code))
171127
self.cfo_locs = sorted(
172-
find_codeflash_output_assignments(normalized_body_code)
128+
find_codeflash_output_assignments(qualifed_name, normalized_body_code)
173129
) # sorted in order we will encounter them
174130
self.cfo_idx_loc_to_look_at = -1
175131
self.context_stack.append(node.name.value)
@@ -179,24 +135,14 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
179135
self.context_stack.pop()
180136
return updated_node
181137

182-
def leave_SimpleStatementLine(
183-
self,
184-
original_node: cst.SimpleStatementLine, # noqa: ARG002
185-
updated_node: cst.SimpleStatementLine,
186-
) -> cst.SimpleStatementLine:
187-
# Look for assignment statements that assign to codeflash_output
188-
# Handle both single statements and multiple statements on one line
189-
codeflash_assignment_found = False
190-
for stmt in updated_node.body:
191-
if isinstance(stmt, cst.Assign) and (
192-
len(stmt.targets) == 1
193-
and isinstance(stmt.targets[0].target, cst.Name)
194-
and stmt.targets[0].target.value == "codeflash_output"
195-
):
196-
codeflash_assignment_found = True
197-
break
198-
199-
if codeflash_assignment_found:
138+
def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
139+
"""
140+
Detect calls to:
141+
- myfunc()
142+
- obj.myfunc()
143+
"""
144+
func_name = self._get_called_func_name(node.func)
145+
if func_name == self.name:
200146
# Find matching test cases by looking for this test function name in the test results
201147
self.cfo_idx_loc_to_look_at += 1
202148
matching_original_times = []
@@ -264,17 +210,23 @@ def leave_SimpleStatementLine(
264210
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
265211
)
266212

267-
# Add comment to the trailing whitespace
268-
new_trailing_whitespace = cst.TrailingWhitespace(
269-
whitespace=cst.SimpleWhitespace(" "),
270-
comment=cst.Comment(comment_text),
271-
newline=updated_node.trailing_whitespace.newline,
272-
)
273-
213+
whitespace = updated_node.trailing_whitespace
214+
new_trailing_whitespace = whitespace.with_changes(comment=cst.Comment(comment_text))
274215
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
275-
276216
return updated_node
277217

218+
def _get_called_func_name(self, node):
219+
"""
220+
Extract the last part of the function name:
221+
- cst.Name(value='myfunc')
222+
- cst.Attribute(attr=cst.Name(value='myfunc'))
223+
"""
224+
if isinstance(node, cst.Name):
225+
return node.value
226+
elif isinstance(node, cst.Attribute):
227+
return node.attr.value
228+
return None
229+
278230
# Process each generated test
279231
modified_tests = []
280232
for test in generated_tests.generated_tests:

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,15 +1012,16 @@ def find_and_process_best_optimization(
10121012
optimized_runtime_by_test = (
10131013
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
10141014
)
1015+
qualifed_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
10151016
# Add runtime comments to generated tests before creating the PR
10161017
generated_tests = add_runtime_comments_to_generated_tests(
1017-
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
1018+
qualifed_name, self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
10181019
)
10191020
generated_tests_str = "\n\n".join(
10201021
[test.generated_original_test_source for test in generated_tests.generated_tests]
10211022
)
10221023
existing_tests = existing_tests_source_for(
1023-
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1024+
qualifed_name,
10241025
function_to_all_tests,
10251026
test_cfg=self.test_cfg,
10261027
original_runtimes_all=original_runtime_by_test,

0 commit comments

Comments
 (0)