1111
1212from codeflash .cli_cmds .console import logger
1313from codeflash .code_utils .time_utils import format_perf , format_time
14- from codeflash .models .models import GeneratedTests , GeneratedTestsList
14+ from codeflash .models .models import (GeneratedTests , GeneratedTestsList ,
15+ InvocationId )
1516from codeflash .result .critic import performance_gain
17+ from codeflash .verification .verification_utils import TestConfig
1618
1719if TYPE_CHECKING :
1820 from codeflash .models .models import InvocationId
@@ -90,7 +92,35 @@ def add_runtime_comments_to_generated_tests(
9092 module_root = test_cfg .project_root_path
9193 rel_tests_root = tests_root .relative_to (module_root )
9294
93- # TODO: reduce for loops to one
95+ # ---- Preindex invocation results for O(1) matching -------
96+ # (rel_path, qualified_name, cfo_loc) -> list[runtimes]
97+ def _make_index (invocations ):
98+ index = {}
99+ for invocation_id , runtimes in invocations .items ():
100+ test_class = invocation_id .test_class_name
101+ test_func = invocation_id .test_function_name
102+ q_name = f"{ test_class } .{ test_func } " if test_class else test_func
103+ rel_path = Path (invocation_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" )
104+ # Defensive: sometimes path processing can fail, fallback to string
105+ try :
106+ rel_path = rel_path .relative_to (rel_tests_root )
107+ except Exception :
108+ rel_path = str (rel_path )
109+ # Get CFO location integer
110+ try :
111+ cfo_loc = int (invocation_id .iteration_id .split ("_" )[0 ])
112+ except Exception :
113+ cfo_loc = None
114+ key = (str (rel_path ), q_name , cfo_loc )
115+ if key not in index :
116+ index [key ] = []
117+ index [key ].extend (runtimes )
118+ return index
119+
120+ orig_index = _make_index (original_runtimes )
121+ opt_index = _make_index (optimized_runtimes )
122+
123+ # Optimized fast CST visitor base
94124 class RuntimeCommentTransformer (cst .CSTTransformer ):
95125 def __init__ (
96126 self , qualified_name : str , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path
@@ -104,104 +134,66 @@ def __init__(
104134 self .cfo_locs : list [int ] = []
105135 self .cfo_idx_loc_to_look_at : int = - 1
106136 self .name = qualified_name .split ("." )[- 1 ]
137+ # Precompute test-local file relative paths for efficiency
138+ self .test_rel_behavior = str (test .behavior_file_path .relative_to (tests_root ))
139+ self .test_rel_perf = str (test .perf_file_path .relative_to (tests_root ))
107140
108141 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
109- # Track when we enter a class
110142 self .context_stack .append (node .name .value )
111143
112- def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef : # noqa: ARG002
113- # Pop the context when we leave a class
144+ def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef :
114145 self .context_stack .pop ()
115146 return updated_node
116147
117148 def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
118- # convert function body to ast normalized string and find occurrences of codeflash_output
149+ # This could be optimized further if you access CFO assignments via CST
119150 body_code = dedent (self .module .code_for_node (node .body ))
120151 normalized_body_code = ast .unparse (ast .parse (body_code ))
121- self .cfo_locs = sorted (
122- find_codeflash_output_assignments (qualified_name , normalized_body_code )
123- ) # sorted in order we will encounter them
152+ self .cfo_locs = sorted (find_codeflash_output_assignments (qualified_name , normalized_body_code ))
124153 self .cfo_idx_loc_to_look_at = - 1
125154 self .context_stack .append (node .name .value )
126155
127- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
128- # Pop the context when we leave a function
156+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
129157 self .context_stack .pop ()
130158 return updated_node
131159
132160 def leave_SimpleStatementLine (
133- self ,
134- original_node : cst .SimpleStatementLine , # noqa: ARG002
135- updated_node : cst .SimpleStatementLine ,
161+ self , original_node : cst .SimpleStatementLine , updated_node : cst .SimpleStatementLine
136162 ) -> cst .SimpleStatementLine :
137- # Check if this statement line contains a call to self.name
138- if self ._contains_myfunc_call (updated_node ): # type: ignore[no-untyped-call]
139- # Find matching test cases by looking for this test function name in the test results
163+ # Fast skip before deep call tree walk by screening for Name nodes
164+ if self ._contains_myfunc_call (updated_node ):
140165 self .cfo_idx_loc_to_look_at += 1
141- matching_original_times = []
142- matching_optimized_times = []
143- # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
144- for invocation_id , runtimes in original_runtimes .items ():
145- # get position here and match in if condition
146- qualified_name = (
147- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
148- if invocation_id .test_class_name
149- else invocation_id .test_function_name
150- )
151- rel_path = (
152- Path (invocation_id .test_module_path .replace ("." , os .sep ))
153- .with_suffix (".py" )
154- .relative_to (self .rel_tests_root )
155- )
156- if (
157- qualified_name == "." .join (self .context_stack )
158- and rel_path
159- in [
160- self .test .behavior_file_path .relative_to (self .tests_root ),
161- self .test .perf_file_path .relative_to (self .tests_root ),
162- ]
163- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
164- ):
165- matching_original_times .extend (runtimes )
166-
167- for invocation_id , runtimes in optimized_runtimes .items ():
168- # get position here and match in if condition
169- qualified_name = (
170- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
171- if invocation_id .test_class_name
172- else invocation_id .test_function_name
173- )
174- rel_path = (
175- Path (invocation_id .test_module_path .replace ("." , os .sep ))
176- .with_suffix (".py" )
177- .relative_to (self .rel_tests_root )
178- )
179- if (
180- qualified_name == "." .join (self .context_stack )
181- and rel_path
182- in [
183- self .test .behavior_file_path .relative_to (self .tests_root ),
184- self .test .perf_file_path .relative_to (self .tests_root ),
185- ]
186- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
187- ):
188- matching_optimized_times .extend (runtimes )
189-
190- if matching_original_times and matching_optimized_times :
191- original_time = min (matching_original_times )
192- optimized_time = min (matching_optimized_times )
166+ if self .cfo_idx_loc_to_look_at >= len (self .cfo_locs ):
167+ return updated_node # Defensive, should never happen
168+
169+ cfo_loc = self .cfo_locs [self .cfo_idx_loc_to_look_at ]
170+
171+ qualified_name_chain = "." .join (self .context_stack )
172+ # Try both behavior and perf as possible locations; both are strings
173+ possible_paths = {self .test_rel_behavior , self .test_rel_perf }
174+
175+ # Form index key(s)
176+ matching_original = []
177+ matching_optimized = []
178+
179+ for rel_path_str in possible_paths :
180+ key = (rel_path_str , qualified_name_chain , cfo_loc )
181+ if key in orig_index :
182+ matching_original .extend (orig_index [key ])
183+ if key in opt_index :
184+ matching_optimized .extend (opt_index [key ])
185+ if matching_original and matching_optimized :
186+ original_time = min (matching_original )
187+ optimized_time = min (matching_optimized )
193188 if original_time != 0 and optimized_time != 0 :
194- perf_gain = format_perf (
189+ perf_gain_str = format_perf (
195190 abs (
196191 performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time )
197192 * 100
198193 )
199194 )
200195 status = "slower" if optimized_time > original_time else "faster"
201- # Create the runtime comment
202- comment_text = (
203- f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
204- )
196+ comment_text = f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain_str } % { status } )"
205197 return updated_node .with_changes (
206198 trailing_whitespace = cst .TrailingWhitespace (
207199 whitespace = cst .SimpleWhitespace (" " ),
@@ -211,43 +203,37 @@ def leave_SimpleStatementLine(
211203 )
212204 return updated_node
213205
214- def _contains_myfunc_call (self , node ): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
206+ def _contains_myfunc_call (self , node ):
215207 """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
216208
209+ # IMPORTANT micro-optimization: early abort using an exception
210+ class Found (Exception ):
211+ pass
212+
217213 class Finder (cst .CSTVisitor ):
218- def __init__ (self , name : str ) -> None :
219- super ().__init__ ()
220- self .found = False
214+ def __init__ (self , name ):
221215 self .name = name
222216
223- def visit_Call (self , call_node ) -> None : # type: ignore[no-untyped-def] # noqa : ANN001
217+ def visit_Call (self , call_node ):
224218 func_expr = call_node .func
225- if isinstance (func_expr , cst .Name ):
226- if func_expr .value == self .name :
227- self .found = True
228- elif isinstance (func_expr , cst .Attribute ): # noqa : SIM102
229- if func_expr .attr .value == self .name :
230- self .found = True
231-
232- finder = Finder (self .name )
233- node .visit (finder )
234- return finder .found
235-
236- # Process each generated test
219+ if (isinstance (func_expr , cst .Name ) and func_expr .value == self .name ) or (
220+ isinstance (func_expr , cst .Attribute ) and func_expr .attr .value == self .name
221+ ):
222+ raise Found
223+
224+ try :
225+ node .visit (Finder (self .name ))
226+ except Found :
227+ return True
228+ return False
229+
237230 modified_tests = []
238231 for test in generated_tests .generated_tests :
239232 try :
240- # Parse the test source code
241233 tree = cst .parse_module (test .generated_original_test_source )
242- # Transform the tree to add runtime comments
243- # qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
244234 transformer = RuntimeCommentTransformer (qualified_name , tree , test , tests_root , rel_tests_root )
245235 modified_tree = tree .visit (transformer )
246-
247- # Convert back to source code
248236 modified_source = modified_tree .code
249-
250- # Create a new GeneratedTests object with the modified source
251237 modified_test = GeneratedTests (
252238 generated_original_test_source = modified_source ,
253239 instrumented_behavior_test_source = test .instrumented_behavior_test_source ,
@@ -257,7 +243,6 @@ def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa
257243 )
258244 modified_tests .append (modified_test )
259245 except Exception as e :
260- # If parsing fails, keep the original test
261246 logger .debug (f"Failed to add runtime comments to test: { e } " )
262247 modified_tests .append (test )
263248
0 commit comments