55import re
66from pathlib import Path
77from textwrap import dedent
8- from typing import TYPE_CHECKING , Union
8+ from typing import TYPE_CHECKING , Union , Optional
99
1010import libcst as cst
1111
@@ -50,89 +50,45 @@ class CfoVisitor(ast.NodeVisitor):
5050 and reports their location relative to the function they're in.
5151 """
5252
53- def __init__ (self , source_code : str ) -> None :
53+ def __init__ (self , qualifed_name : str , source_code : str ) -> None :
5454 self .source_lines = source_code .splitlines ()
55+ self .name = qualifed_name .split ('.' )[- 1 ]
5556 self .results : list [int ] = [] # map actual line number to line number in ast
5657
57- def _is_codeflash_output_target (self , target : Union [ast .expr , list ]) -> bool : # type: ignore[type-arg]
58- """Check if the assignment target is the variable 'codeflash_output'."""
59- if isinstance (target , ast .Name ):
60- return target .id == "codeflash_output"
61- if isinstance (target , (ast .Tuple , ast .List )):
62- # Handle tuple/list unpacking: a, codeflash_output, b = values
63- return any (self ._is_codeflash_output_target (elt ) for elt in target .elts )
64- if isinstance (target , (ast .Subscript , ast .Attribute )):
65- # Not a simple variable assignment
66- return False
67- return False
68-
69- def _record_assignment (self , node : ast .AST ) -> None :
70- """Record an assignment to codeflash_output."""
71- relative_line = node .lineno - 1 # type: ignore[attr-defined]
72- self .results .append (relative_line )
73-
74- def visit_Assign (self , node : ast .Assign ) -> None :
75- """Visit assignment statements: codeflash_output = value."""
76- for target in node .targets :
77- if self ._is_codeflash_output_target (target ):
78- self ._record_assignment (node )
79- break
58+ def visit_Call (self , node ):
59+ """
60+ Detect calls to:
61+ - myfunc(...)
62+ - obj.myfunc(...)
63+ """
64+ func_name = self ._get_called_func_name (node .func )
65+ if func_name == self .name :
66+ self .results .append (node .lineno )
8067 self .generic_visit (node )
8168
82- def visit_AnnAssign (self , node : ast .AnnAssign ) -> None :
83- """Visit annotated assignments: codeflash_output: int = value."""
84- if self ._is_codeflash_output_target (node .target ):
85- self ._record_assignment (node )
86- self .generic_visit (node )
87-
88- def visit_AugAssign (self , node : ast .AugAssign ) -> None :
89- """Visit augmented assignments: codeflash_output += value."""
90- if self ._is_codeflash_output_target (node .target ):
91- self ._record_assignment (node )
92- self .generic_visit (node )
93-
94- def visit_NamedExpr (self , node : ast .NamedExpr ) -> None :
95- """Visit walrus operator: (codeflash_output := value)."""
96- if isinstance (node .target , ast .Name ) and node .target .id == "codeflash_output" :
97- self ._record_assignment (node )
98- self .generic_visit (node )
99-
100- def visit_For (self , node : ast .For ) -> None :
101- """Visit for loops: for codeflash_output in iterable."""
102- if self ._is_codeflash_output_target (node .target ):
103- self ._record_assignment (node )
104- self .generic_visit (node )
105-
106- def visit_comprehension (self , node : ast .comprehension ) -> None :
107- """Visit comprehensions: [x for codeflash_output in iterable]."""
108- if self ._is_codeflash_output_target (node .target ):
109- # Comprehensions don't have line numbers, so we skip recording
110- pass
111- self .generic_visit (node )
112-
113- def visit_With (self , node : ast .With ) -> None :
114- """Visit with statements: with expr as codeflash_output."""
115- for item in node .items :
116- if item .optional_vars and self ._is_codeflash_output_target (item .optional_vars ):
117- self ._record_assignment (node )
118- break
119- self .generic_visit (node )
120-
121- def visit_ExceptHandler (self , node : ast .ExceptHandler ) -> None :
122- """Visit except handlers: except Exception as codeflash_output."""
123- if node .name == "codeflash_output" :
124- self ._record_assignment (node )
125- self .generic_visit (node )
126-
127-
128- def find_codeflash_output_assignments (source_code : str ) -> list [int ]:
69+ def _get_called_func_name (self , node ):
70+ """
71+ Given a node like:
72+ - Name(id='myfunc')
73+ - Attribute(value=..., attr='myfunc')
74+ Return the final function name if possible.
75+ """
76+ if isinstance (node , ast .Name ):
77+ return node .id
78+ elif isinstance (node , ast .Attribute ):
79+ return node .attr
80+ return None
81+
82+
83+ def find_codeflash_output_assignments (qualifed_name : str , source_code : str ) -> list [int ]:
12984 tree = ast .parse (source_code )
130- visitor = CfoVisitor (source_code )
85+ visitor = CfoVisitor (qualifed_name , source_code )
13186 visitor .visit (tree )
13287 return visitor .results
13388
13489
13590def add_runtime_comments_to_generated_tests (
91+ qualifed_name : str ,
13692 test_cfg : TestConfig ,
13793 generated_tests : GeneratedTestsList ,
13894 original_runtimes : dict [InvocationId , list [int ]],
@@ -169,7 +125,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
169125 body_code = dedent (self .module .code_for_node (node .body ))
170126 normalized_body_code = ast .unparse (ast .parse (body_code ))
171127 self .cfo_locs = sorted (
172- find_codeflash_output_assignments (normalized_body_code )
128+ find_codeflash_output_assignments (qualifed_name , normalized_body_code )
173129 ) # sorted in order we will encounter them
174130 self .cfo_idx_loc_to_look_at = - 1
175131 self .context_stack .append (node .name .value )
@@ -179,24 +135,14 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
179135 self .context_stack .pop ()
180136 return updated_node
181137
182- def leave_SimpleStatementLine (
183- self ,
184- original_node : cst .SimpleStatementLine , # noqa: ARG002
185- updated_node : cst .SimpleStatementLine ,
186- ) -> cst .SimpleStatementLine :
187- # Look for assignment statements that assign to codeflash_output
188- # Handle both single statements and multiple statements on one line
189- codeflash_assignment_found = False
190- for stmt in updated_node .body :
191- if isinstance (stmt , cst .Assign ) and (
192- len (stmt .targets ) == 1
193- and isinstance (stmt .targets [0 ].target , cst .Name )
194- and stmt .targets [0 ].target .value == "codeflash_output"
195- ):
196- codeflash_assignment_found = True
197- break
198-
199- if codeflash_assignment_found :
138+ def leave_Call (self , node : cst .Call , updated_node : cst .Call ) -> cst .Call :
139+ """
140+ Detect calls to:
141+ - myfunc()
142+ - obj.myfunc()
143+ """
144+ func_name = self ._get_called_func_name (node .func )
145+ if func_name == self .name :
200146 # Find matching test cases by looking for this test function name in the test results
201147 self .cfo_idx_loc_to_look_at += 1
202148 matching_original_times = []
@@ -264,17 +210,23 @@ def leave_SimpleStatementLine(
264210 f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
265211 )
266212
267- # Add comment to the trailing whitespace
268- new_trailing_whitespace = cst .TrailingWhitespace (
269- whitespace = cst .SimpleWhitespace (" " ),
270- comment = cst .Comment (comment_text ),
271- newline = updated_node .trailing_whitespace .newline ,
272- )
273-
213+ whitespace = updated_node .trailing_whitespace
214+ new_trailing_whitespace = whitespace .with_changes (comment = cst .Comment (comment_text ))
274215 return updated_node .with_changes (trailing_whitespace = new_trailing_whitespace )
275-
276216 return updated_node
277217
218+ def _get_called_func_name (self , node ):
219+ """
220+ Extract the last part of the function name:
221+ - cst.Name(value='myfunc')
222+ - cst.Attribute(attr=cst.Name(value='myfunc'))
223+ """
224+ if isinstance (node , cst .Name ):
225+ return node .value
226+ elif isinstance (node , cst .Attribute ):
227+ return node .attr .value
228+ return None
229+
278230 # Process each generated test
279231 modified_tests = []
280232 for test in generated_tests .generated_tests :
0 commit comments