55import re
66from pathlib import Path
77from textwrap import dedent
8- from typing import TYPE_CHECKING , Union , Optional
8+ from typing import TYPE_CHECKING
99
1010import libcst as cst
1111
@@ -52,30 +52,21 @@ class CfoVisitor(ast.NodeVisitor):
5252
5353 def __init__ (self , qualifed_name : str , source_code : str ) -> None :
5454 self .source_lines = source_code .splitlines ()
55- self .name = qualifed_name .split ('.' )[- 1 ]
55+ self .name = qualifed_name .split ("." )[- 1 ]
5656 self .results : list [int ] = [] # map actual line number to line number in ast
5757
58- def visit_Call (self , node ):
59- """
60- Detect calls to:
61- - myfunc(...)
62- - obj.myfunc(...)
63- """
58+ def visit_Call (self , node ): # noqa: ANN201, ANN001
59+ """Detect fn calls."""
6460 func_name = self ._get_called_func_name (node .func )
6561 if func_name == self .name :
66- self .results .append (node .lineno )
62+ self .results .append (node .lineno - 1 )
6763 self .generic_visit (node )
6864
69- def _get_called_func_name (self , node ):
70- """
71- Given a node like:
72- - Name(id='myfunc')
73- - Attribute(value=..., attr='myfunc')
74- Return the final function name if possible.
75- """
65+ def _get_called_func_name (self , node ): # noqa: ANN001, ANN202
66+ """Return name of called fn."""
7667 if isinstance (node , ast .Name ):
7768 return node .id
78- elif isinstance (node , ast .Attribute ):
69+ if isinstance (node , ast .Attribute ):
7970 return node .attr
8071 return None
8172
@@ -101,7 +92,9 @@ def add_runtime_comments_to_generated_tests(
10192
10293 # TODO: reduce for loops to one
10394 class RuntimeCommentTransformer (cst .CSTTransformer ):
104- def __init__ (self , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
95+ def __init__ (
96+ self , qualified_name : str , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path
97+ ) -> None :
10598 super ().__init__ ()
10699 self .test = test
107100 self .context_stack : list [str ] = []
@@ -110,6 +103,7 @@ def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, r
110103 self .module = module
111104 self .cfo_locs : list [int ] = []
112105 self .cfo_idx_loc_to_look_at : int = - 1
106+ self .name = qualified_name .split ("." )[- 1 ]
113107
114108 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
115109 # Track when we enter a class
@@ -135,14 +129,11 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
135129 self .context_stack .pop ()
136130 return updated_node
137131
138- def leave_Call (self , node : cst .Call , updated_node : cst .Call ) -> cst .Call :
139- """
140- Detect calls to:
141- - myfunc()
142- - obj.myfunc()
143- """
144- func_name = self ._get_called_func_name (node .func )
145- if func_name == self .name :
132+ def leave_SimpleStatementLine (
133+ self , original_node : cst .SimpleStatementLine , updated_node : cst .SimpleStatementLine
134+ ) -> cst .SimpleStatementLine :
135+ # Check if this statement line contains a call to self.name
136+ if self ._contains_myfunc_call (updated_node ):
146137 # Find matching test cases by looking for this test function name in the test results
147138 self .cfo_idx_loc_to_look_at += 1
148139 matching_original_times = []
@@ -209,23 +200,36 @@ def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
209200 comment_text = (
210201 f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
211202 )
212-
213- whitespace = updated_node .trailing_whitespace
214- new_trailing_whitespace = whitespace .with_changes (comment = cst .Comment (comment_text ))
215- return updated_node .with_changes (trailing_whitespace = new_trailing_whitespace )
203+ return updated_node .with_changes (
204+ trailing_whitespace = cst .TrailingWhitespace (
205+ whitespace = cst .SimpleWhitespace (" " ),
206+ comment = cst .Comment (comment_text ),
207+ newline = updated_node .trailing_whitespace .newline ,
208+ )
209+ )
216210 return updated_node
217211
218- def _get_called_func_name (self , node ):
219- """
220- Extract the last part of the function name:
221- - cst.Name(value='myfunc')
222- - cst.Attribute(attr=cst.Name(value='myfunc'))
223- """
224- if isinstance (node , cst .Name ):
225- return node .value
226- elif isinstance (node , cst .Attribute ):
227- return node .attr .value
228- return None
212+ def _contains_myfunc_call (self , node ):
213+ """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
214+
215+ class Finder (cst .CSTVisitor ):
216+ def __init__ (self , name : str ):
217+ super ().__init__ ()
218+ self .found = False
219+ self .name = name
220+
221+ def visit_Call (self , call_node ):
222+ func_expr = call_node .func
223+ if isinstance (func_expr , cst .Name ):
224+ if func_expr .value == self .name :
225+ self .found = True
226+ elif isinstance (func_expr , cst .Attribute ):
227+ if func_expr .attr .value == self .name :
228+ self .found = True
229+
230+ finder = Finder (self .name )
231+ node .visit (finder )
232+ return finder .found
229233
230234 # Process each generated test
231235 modified_tests = []
@@ -234,7 +238,8 @@ def _get_called_func_name(self, node):
234238 # Parse the test source code
235239 tree = cst .parse_module (test .generated_original_test_source )
236240 # Transform the tree to add runtime comments
237- transformer = RuntimeCommentTransformer (tree , test , tests_root , rel_tests_root )
241+ # qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
242+ transformer = RuntimeCommentTransformer (qualifed_name , tree , test , tests_root , rel_tests_root )
238243 modified_tree = tree .visit (transformer )
239244
240245 # Convert back to source code
0 commit comments