@@ -50,9 +50,9 @@ class CfoVisitor(ast.NodeVisitor):
5050 and reports their location relative to the function they're in.
5151 """
5252
53- def __init__ (self , qualified_name : str , source_code : str ) -> None :
53+ def __init__ (self , function_name : str , source_code : str ) -> None :
5454 self .source_lines = source_code .splitlines ()
55- self .name = qualified_name . split ( "." )[ - 1 ]
55+ self .name = function_name
5656 self .results : list [int ] = [] # map actual line number to line number in ast
5757
5858 def visit_Call (self , node ): # type: ignore[no-untyped-def] # noqa: ANN201, ANN001
@@ -71,13 +71,166 @@ def _get_called_func_name(self, node): # type: ignore[no-untyped-def] # noqa: A
7171 return None
7272
7373
74- def find_codeflash_output_assignments (qualified_name : str , source_code : str ) -> list [int ]:
74+ def find_codeflash_output_assignments (function_name : str , source_code : str ) -> list [int ]:
7575 tree = ast .parse (source_code )
76- visitor = CfoVisitor (qualified_name , source_code )
76+ visitor = CfoVisitor (function_name , source_code )
7777 visitor .visit (tree )
7878 return visitor .results
7979
8080
81+ class Finder (cst .CSTVisitor ):
82+ def __init__ (self , name : str ) -> None :
83+ super ().__init__ ()
84+ self .found = False
85+ self .name = name
86+
87+ def visit_Call (self , call_node ) -> None : # type: ignore[no-untyped-def] # noqa : ANN001
88+ func_expr = call_node .func
89+ if isinstance (func_expr , cst .Name ):
90+ if func_expr .value == self .name :
91+ self .found = True
92+ elif isinstance (func_expr , cst .Attribute ): # noqa : SIM102
93+ if func_expr .attr .value == self .name :
94+ self .found = True
95+
96+
97+ # TODO: reduce for loops to one
98+ class RuntimeCommentTransformer (cst .CSTTransformer ):
99+ def __init__ (
100+ self ,
101+ qualified_name : str ,
102+ module : cst .Module ,
103+ test : GeneratedTests ,
104+ tests_root : Path ,
105+ rel_tests_root : Path ,
106+ original_runtimes : dict [InvocationId , list [int ]],
107+ optimized_runtimes : dict [InvocationId , list [int ]],
108+ ) -> None :
109+ super ().__init__ ()
110+ self .test = test
111+ self .context_stack : list [str ] = []
112+ self .tests_root = tests_root
113+ self .rel_tests_root = rel_tests_root
114+ self .module = module
115+ self .cfo_locs : list [int ] = []
116+ self .cfo_idx_loc_to_look_at : int = - 1
117+ self .name = qualified_name .split ("." )[- 1 ]
118+ self .original_runtimes = original_runtimes
119+ self .optimized_runtimes = optimized_runtimes
120+
121+ def visit_ClassDef (self , node : cst .ClassDef ) -> None :
122+ # Track when we enter a class
123+ self .context_stack .append (node .name .value )
124+
125+ def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef : # noqa: ARG002
126+ # Pop the context when we leave a class
127+ self .context_stack .pop ()
128+ return updated_node
129+
130+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
131+ # convert function body to ast normalized string and find occurrences of codeflash_output
132+ body_code = dedent (self .module .code_for_node (node .body ))
133+ normalized_body_code = ast .unparse (ast .parse (body_code ))
134+ self .cfo_locs = sorted (
135+ find_codeflash_output_assignments (self .name , normalized_body_code )
136+ ) # sorted in order we will encounter them
137+ self .cfo_idx_loc_to_look_at = - 1
138+ self .context_stack .append (node .name .value )
139+
140+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
141+ # Pop the context when we leave a function
142+ self .context_stack .pop ()
143+ return updated_node
144+
145+ def leave_SimpleStatementLine (
146+ self ,
147+ original_node : cst .SimpleStatementLine , # noqa: ARG002
148+ updated_node : cst .SimpleStatementLine ,
149+ ) -> cst .SimpleStatementLine :
150+ # Check if this statement line contains a call to self.name
151+ if self ._contains_myfunc_call (updated_node ): # type: ignore[no-untyped-call]
152+ # Find matching test cases by looking for this test function name in the test results
153+ self .cfo_idx_loc_to_look_at += 1
154+ matching_original_times = []
155+ matching_optimized_times = []
156+ # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
157+ for invocation_id , runtimes in self .original_runtimes .items ():
158+ # get position here and match in if condition
159+ qualified_name = (
160+ invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
161+ if invocation_id .test_class_name
162+ else invocation_id .test_function_name
163+ )
164+ rel_path = (
165+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
166+ .with_suffix (".py" )
167+ .relative_to (self .rel_tests_root )
168+ )
169+ if (
170+ qualified_name == "." .join (self .context_stack )
171+ and rel_path
172+ in [
173+ self .test .behavior_file_path .relative_to (self .tests_root ),
174+ self .test .perf_file_path .relative_to (self .tests_root ),
175+ ]
176+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
177+ ):
178+ matching_original_times .extend (runtimes )
179+
180+ for invocation_id , runtimes in self .optimized_runtimes .items ():
181+ # get position here and match in if condition
182+ qualified_name = (
183+ invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
184+ if invocation_id .test_class_name
185+ else invocation_id .test_function_name
186+ )
187+ rel_path = (
188+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
189+ .with_suffix (".py" )
190+ .relative_to (self .rel_tests_root )
191+ )
192+ if (
193+ qualified_name == "." .join (self .context_stack )
194+ and rel_path
195+ in [
196+ self .test .behavior_file_path .relative_to (self .tests_root ),
197+ self .test .perf_file_path .relative_to (self .tests_root ),
198+ ]
199+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
200+ ):
201+ matching_optimized_times .extend (runtimes )
202+
203+ if matching_original_times and matching_optimized_times :
204+ original_time = min (matching_original_times )
205+ optimized_time = min (matching_optimized_times )
206+ if original_time != 0 and optimized_time != 0 :
207+ perf_gain = format_perf (
208+ abs (
209+ performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time )
210+ * 100
211+ )
212+ )
213+ status = "slower" if optimized_time > original_time else "faster"
214+ # Create the runtime comment
215+ comment_text = (
216+ f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
217+ )
218+ return updated_node .with_changes (
219+ trailing_whitespace = cst .TrailingWhitespace (
220+ whitespace = cst .SimpleWhitespace (" " ),
221+ comment = cst .Comment (comment_text ),
222+ newline = updated_node .trailing_whitespace .newline ,
223+ )
224+ )
225+ return updated_node
226+
227+ def _contains_myfunc_call (self , node ): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
228+ """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
229+ finder = Finder (self .name )
230+ node .visit (finder )
231+ return finder .found
232+
233+
81234def add_runtime_comments_to_generated_tests (
82235 qualified_name : str ,
83236 test_cfg : TestConfig ,
@@ -90,149 +243,6 @@ def add_runtime_comments_to_generated_tests(
90243 module_root = test_cfg .project_root_path
91244 rel_tests_root = tests_root .relative_to (module_root )
92245
93- # TODO: reduce for loops to one
94- class RuntimeCommentTransformer (cst .CSTTransformer ):
95- def __init__ (
96- self , qualified_name : str , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path
97- ) -> None :
98- super ().__init__ ()
99- self .test = test
100- self .context_stack : list [str ] = []
101- self .tests_root = tests_root
102- self .rel_tests_root = rel_tests_root
103- self .module = module
104- self .cfo_locs : list [int ] = []
105- self .cfo_idx_loc_to_look_at : int = - 1
106- self .name = qualified_name .split ("." )[- 1 ]
107-
108- def visit_ClassDef (self , node : cst .ClassDef ) -> None :
109- # Track when we enter a class
110- self .context_stack .append (node .name .value )
111-
112- def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef : # noqa: ARG002
113- # Pop the context when we leave a class
114- self .context_stack .pop ()
115- return updated_node
116-
117- def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
118- # convert function body to ast normalized string and find occurrences of codeflash_output
119- body_code = dedent (self .module .code_for_node (node .body ))
120- normalized_body_code = ast .unparse (ast .parse (body_code ))
121- self .cfo_locs = sorted (
122- find_codeflash_output_assignments (qualified_name , normalized_body_code )
123- ) # sorted in order we will encounter them
124- self .cfo_idx_loc_to_look_at = - 1
125- self .context_stack .append (node .name .value )
126-
127- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
128- # Pop the context when we leave a function
129- self .context_stack .pop ()
130- return updated_node
131-
132- def leave_SimpleStatementLine (
133- self ,
134- original_node : cst .SimpleStatementLine , # noqa: ARG002
135- updated_node : cst .SimpleStatementLine ,
136- ) -> cst .SimpleStatementLine :
137- # Check if this statement line contains a call to self.name
138- if self ._contains_myfunc_call (updated_node ): # type: ignore[no-untyped-call]
139- # Find matching test cases by looking for this test function name in the test results
140- self .cfo_idx_loc_to_look_at += 1
141- matching_original_times = []
142- matching_optimized_times = []
143- # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
144- for invocation_id , runtimes in original_runtimes .items ():
145- # get position here and match in if condition
146- qualified_name = (
147- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
148- if invocation_id .test_class_name
149- else invocation_id .test_function_name
150- )
151- rel_path = (
152- Path (invocation_id .test_module_path .replace ("." , os .sep ))
153- .with_suffix (".py" )
154- .relative_to (self .rel_tests_root )
155- )
156- if (
157- qualified_name == "." .join (self .context_stack )
158- and rel_path
159- in [
160- self .test .behavior_file_path .relative_to (self .tests_root ),
161- self .test .perf_file_path .relative_to (self .tests_root ),
162- ]
163- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
164- ):
165- matching_original_times .extend (runtimes )
166-
167- for invocation_id , runtimes in optimized_runtimes .items ():
168- # get position here and match in if condition
169- qualified_name = (
170- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
171- if invocation_id .test_class_name
172- else invocation_id .test_function_name
173- )
174- rel_path = (
175- Path (invocation_id .test_module_path .replace ("." , os .sep ))
176- .with_suffix (".py" )
177- .relative_to (self .rel_tests_root )
178- )
179- if (
180- qualified_name == "." .join (self .context_stack )
181- and rel_path
182- in [
183- self .test .behavior_file_path .relative_to (self .tests_root ),
184- self .test .perf_file_path .relative_to (self .tests_root ),
185- ]
186- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
187- ):
188- matching_optimized_times .extend (runtimes )
189-
190- if matching_original_times and matching_optimized_times :
191- original_time = min (matching_original_times )
192- optimized_time = min (matching_optimized_times )
193- if original_time != 0 and optimized_time != 0 :
194- perf_gain = format_perf (
195- abs (
196- performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time )
197- * 100
198- )
199- )
200- status = "slower" if optimized_time > original_time else "faster"
201- # Create the runtime comment
202- comment_text = (
203- f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
204- )
205- return updated_node .with_changes (
206- trailing_whitespace = cst .TrailingWhitespace (
207- whitespace = cst .SimpleWhitespace (" " ),
208- comment = cst .Comment (comment_text ),
209- newline = updated_node .trailing_whitespace .newline ,
210- )
211- )
212- return updated_node
213-
214- def _contains_myfunc_call (self , node ): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
215- """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
216-
217- class Finder (cst .CSTVisitor ):
218- def __init__ (self , name : str ) -> None :
219- super ().__init__ ()
220- self .found = False
221- self .name = name
222-
223- def visit_Call (self , call_node ) -> None : # type: ignore[no-untyped-def] # noqa : ANN001
224- func_expr = call_node .func
225- if isinstance (func_expr , cst .Name ):
226- if func_expr .value == self .name :
227- self .found = True
228- elif isinstance (func_expr , cst .Attribute ): # noqa : SIM102
229- if func_expr .attr .value == self .name :
230- self .found = True
231-
232- finder = Finder (self .name )
233- node .visit (finder )
234- return finder .found
235-
236246 # Process each generated test
237247 modified_tests = []
238248 for test in generated_tests .generated_tests :
@@ -241,7 +251,9 @@ def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa
241251 tree = cst .parse_module (test .generated_original_test_source )
242252 # Transform the tree to add runtime comments
243253 # qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
244- transformer = RuntimeCommentTransformer (qualified_name , tree , test , tests_root , rel_tests_root )
254+ transformer = RuntimeCommentTransformer (
255+ qualified_name , tree , test , tests_root , rel_tests_root , original_runtimes , optimized_runtimes
256+ )
245257 modified_tree = tree .visit (transformer )
246258
247259 # Convert back to source code
0 commit comments