Skip to content

Commit 4a0abd5

Browse files
committed
todo fix precommit
1 parent 9316ee7 commit 4a0abd5

File tree

3 files changed

+106
-86
lines changed

3 files changed

+106
-86
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from pathlib import Path
77
from textwrap import dedent
8-
from typing import TYPE_CHECKING, Union, Optional
8+
from typing import TYPE_CHECKING
99

1010
import libcst as cst
1111

@@ -52,30 +52,21 @@ class CfoVisitor(ast.NodeVisitor):
5252

5353
def __init__(self, qualifed_name: str, source_code: str) -> None:
5454
self.source_lines = source_code.splitlines()
55-
self.name = qualifed_name.split('.')[-1]
55+
self.name = qualifed_name.split(".")[-1]
5656
self.results: list[int] = [] # map actual line number to line number in ast
5757

58-
def visit_Call(self, node):
59-
"""
60-
Detect calls to:
61-
- myfunc(...)
62-
- obj.myfunc(...)
63-
"""
58+
def visit_Call(self, node): # noqa: ANN201, ANN001
59+
"""Detect fn calls."""
6460
func_name = self._get_called_func_name(node.func)
6561
if func_name == self.name:
66-
self.results.append(node.lineno)
62+
self.results.append(node.lineno - 1)
6763
self.generic_visit(node)
6864

69-
def _get_called_func_name(self, node):
70-
"""
71-
Given a node like:
72-
- Name(id='myfunc')
73-
- Attribute(value=..., attr='myfunc')
74-
Return the final function name if possible.
75-
"""
65+
def _get_called_func_name(self, node): # noqa: ANN001, ANN202
66+
"""Return name of called fn."""
7667
if isinstance(node, ast.Name):
7768
return node.id
78-
elif isinstance(node, ast.Attribute):
69+
if isinstance(node, ast.Attribute):
7970
return node.attr
8071
return None
8172

@@ -101,7 +92,9 @@ def add_runtime_comments_to_generated_tests(
10192

10293
# TODO: reduce for loops to one
10394
class RuntimeCommentTransformer(cst.CSTTransformer):
104-
def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
95+
def __init__(
96+
self, qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
97+
) -> None:
10598
super().__init__()
10699
self.test = test
107100
self.context_stack: list[str] = []
@@ -110,6 +103,7 @@ def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, r
110103
self.module = module
111104
self.cfo_locs: list[int] = []
112105
self.cfo_idx_loc_to_look_at: int = -1
106+
self.name = qualified_name.split(".")[-1]
113107

114108
def visit_ClassDef(self, node: cst.ClassDef) -> None:
115109
# Track when we enter a class
@@ -135,14 +129,11 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
135129
self.context_stack.pop()
136130
return updated_node
137131

138-
def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
139-
"""
140-
Detect calls to:
141-
- myfunc()
142-
- obj.myfunc()
143-
"""
144-
func_name = self._get_called_func_name(node.func)
145-
if func_name == self.name:
132+
def leave_SimpleStatementLine(
133+
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
134+
) -> cst.SimpleStatementLine:
135+
# Check if this statement line contains a call to self.name
136+
if self._contains_myfunc_call(updated_node):
146137
# Find matching test cases by looking for this test function name in the test results
147138
self.cfo_idx_loc_to_look_at += 1
148139
matching_original_times = []
@@ -209,23 +200,36 @@ def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
209200
comment_text = (
210201
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
211202
)
212-
213-
whitespace = updated_node.trailing_whitespace
214-
new_trailing_whitespace = whitespace.with_changes(comment=cst.Comment(comment_text))
215-
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
203+
return updated_node.with_changes(
204+
trailing_whitespace=cst.TrailingWhitespace(
205+
whitespace=cst.SimpleWhitespace(" "),
206+
comment=cst.Comment(comment_text),
207+
newline=updated_node.trailing_whitespace.newline,
208+
)
209+
)
216210
return updated_node
217211

218-
def _get_called_func_name(self, node):
219-
"""
220-
Extract the last part of the function name:
221-
- cst.Name(value='myfunc')
222-
- cst.Attribute(attr=cst.Name(value='myfunc'))
223-
"""
224-
if isinstance(node, cst.Name):
225-
return node.value
226-
elif isinstance(node, cst.Attribute):
227-
return node.attr.value
228-
return None
212+
def _contains_myfunc_call(self, node):
213+
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
214+
215+
class Finder(cst.CSTVisitor):
216+
def __init__(self, name: str):
217+
super().__init__()
218+
self.found = False
219+
self.name = name
220+
221+
def visit_Call(self, call_node):
222+
func_expr = call_node.func
223+
if isinstance(func_expr, cst.Name):
224+
if func_expr.value == self.name:
225+
self.found = True
226+
elif isinstance(func_expr, cst.Attribute):
227+
if func_expr.attr.value == self.name:
228+
self.found = True
229+
230+
finder = Finder(self.name)
231+
node.visit(finder)
232+
return finder.found
229233

230234
# Process each generated test
231235
modified_tests = []
@@ -234,7 +238,8 @@ def _get_called_func_name(self, node):
234238
# Parse the test source code
235239
tree = cst.parse_module(test.generated_original_test_source)
236240
# Transform the tree to add runtime comments
237-
transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
241+
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
242+
transformer = RuntimeCommentTransformer(qualifed_name, tree, test, tests_root, rel_tests_root)
238243
modified_tree = tree.visit(transformer)
239244

240245
# Convert back to source code

codeflash/optimization/function_optimizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,11 @@ def find_and_process_best_optimization(
10151015
qualifed_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
10161016
# Add runtime comments to generated tests before creating the PR
10171017
generated_tests = add_runtime_comments_to_generated_tests(
1018-
qualifed_name, self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
1018+
qualifed_name,
1019+
self.test_cfg,
1020+
generated_tests,
1021+
original_runtime_by_test,
1022+
optimized_runtime_by_test,
10191023
)
10201024
generated_tests_str = "\n\n".join(
10211025
[test.generated_original_test_source for test in generated_tests.generated_tests]

0 commit comments

Comments
 (0)