@@ -32,23 +32,28 @@ def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
3232
3333
3434def  node_in_call_position (node : ast .AST , call_positions : list [CodePosition ]) ->  bool :
35-     if  isinstance (node , ast .Call ) and  hasattr (node , "lineno" ) and  hasattr (node , "col_offset" ):
36-         for  pos  in  call_positions :
37-             if  (
38-                 pos .line_no  is  not   None 
39-                 and  node .end_lineno  is  not   None 
40-                 and  node .lineno  <=  pos .line_no  <=  node .end_lineno 
41-             ):
42-                 if  pos .line_no  ==  node .lineno  and  node .col_offset  <=  pos .col_no :
43-                     return  True 
44-                 if  (
45-                     pos .line_no  ==  node .end_lineno 
46-                     and  node .end_col_offset  is  not   None 
47-                     and  node .end_col_offset  >=  pos .col_no 
48-                 ):
49-                     return  True 
50-                 if  node .lineno  <  pos .line_no  <  node .end_lineno :
51-                     return  True 
35+     # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty. 
36+     # Small optimizations for tight loop: 
37+     if  isinstance (node , ast .Call ):
38+         node_lineno  =  getattr (node , "lineno" , None )
39+         node_col_offset  =  getattr (node , "col_offset" , None )
40+         node_end_lineno  =  getattr (node , "end_lineno" , None )
41+         node_end_col_offset  =  getattr (node , "end_col_offset" , None )
42+         if  node_lineno  is  not   None  and  node_col_offset  is  not   None  and  node_end_lineno  is  not   None :
43+             # Faster loop: reduce attribute lookups, use local variables for conditionals. 
44+             for  pos  in  call_positions :
45+                 pos_line  =  pos .line_no 
46+                 if  pos_line  is  not   None  and  node_lineno  <=  pos_line  <=  node_end_lineno :
47+                     if  pos_line  ==  node_lineno  and  node_col_offset  <=  pos .col_no :
48+                         return  True 
49+                     if  (
50+                         pos_line  ==  node_end_lineno 
51+                         and  node_end_col_offset  is  not   None 
52+                         and  node_end_col_offset  >=  pos .col_no 
53+                     ):
54+                         return  True 
55+                     if  node_lineno  <  pos_line  <  node_end_lineno :
56+                         return  True 
5257    return  False 
5358
5459
@@ -84,28 +89,157 @@ def __init__(
8489    def  find_and_update_line_node (
8590        self , test_node : ast .stmt , node_name : str , index : str , test_class_name : str  |  None  =  None 
8691    ) ->  Iterable [ast .stmt ] |  None :
92+         # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call, 
93+         # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. 
94+ 
95+         # Helper for manual walk 
96+         def  iter_ast_calls (node ):
97+             # Generator to yield each ast.Call in test_node, preserves node identity 
98+             stack  =  [node ]
99+             while  stack :
100+                 n  =  stack .pop ()
101+                 if  isinstance (n , ast .Call ):
102+                     yield  n 
103+                 # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node), 
104+                 # do a specialized BFS with only the necessary attributes 
105+                 for  field , value  in  ast .iter_fields (n ):
106+                     if  isinstance (value , list ):
107+                         for  item  in  reversed (value ):
108+                             if  isinstance (item , ast .AST ):
109+                                 stack .append (item )
110+                     elif  isinstance (value , ast .AST ):
111+                         stack .append (value )
112+ 
113+         # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead 
87114        return_statement  =  [test_node ]
88115        call_node  =  None 
89-         for  node  in  ast .walk (test_node ):
90-             if  isinstance (node , ast .Call ) and  node_in_call_position (node , self .call_positions ):
91-                 call_node  =  node 
92-                 all_args  =  get_call_arguments (call_node )
93-                 if  isinstance (node .func , ast .Name ):
94-                     function_name  =  node .func .id 
95- 
96-                     if  self .function_object .is_async :
116+ 
117+         # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals 
118+         fn_obj  =  self .function_object 
119+         module_path  =  self .module_path 
120+         mode  =  self .mode 
121+         qualified_name  =  fn_obj .qualified_name 
122+ 
123+         # Use locals for all 'current' values, only look up class/function/constant AST object once. 
124+         codeflash_loop_index  =  ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ())
125+         codeflash_cur  =  ast .Name (id = "codeflash_cur" , ctx = ast .Load ())
126+         codeflash_con  =  ast .Name (id = "codeflash_con" , ctx = ast .Load ())
127+ 
128+         for  node  in  iter_ast_calls (test_node ):
129+             if  not  node_in_call_position (node , self .call_positions ):
130+                 continue 
131+ 
132+             call_node  =  node 
133+             all_args  =  get_call_arguments (call_node )
134+             # Two possible call types: Name and Attribute 
135+             node_func  =  node .func 
136+ 
137+             if  isinstance (node_func , ast .Name ):
138+                 function_name  =  node_func .id 
139+ 
140+                 if  fn_obj .is_async :
141+                     return  [test_node ]
142+ 
143+                 # Build once, reuse objects. 
144+                 inspect_name  =  ast .Name (id = "inspect" , ctx = ast .Load ())
145+                 bind_call  =  ast .Assign (
146+                     targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
147+                     value = ast .Call (
148+                         func = ast .Attribute (
149+                             value = ast .Call (
150+                                 func = ast .Attribute (value = inspect_name , attr = "signature" , ctx = ast .Load ()),
151+                                 args = [ast .Name (id = function_name , ctx = ast .Load ())],
152+                                 keywords = [],
153+                             ),
154+                             attr = "bind" ,
155+                             ctx = ast .Load (),
156+                         ),
157+                         args = all_args .args ,
158+                         keywords = all_args .keywords ,
159+                     ),
160+                     lineno = test_node .lineno ,
161+                     col_offset = test_node .col_offset ,
162+                 )
163+ 
164+                 apply_defaults  =  ast .Expr (
165+                     value = ast .Call (
166+                         func = ast .Attribute (
167+                             value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
168+                             attr = "apply_defaults" ,
169+                             ctx = ast .Load (),
170+                         ),
171+                         args = [],
172+                         keywords = [],
173+                     ),
174+                     lineno = test_node .lineno  +  1 ,
175+                     col_offset = test_node .col_offset ,
176+                 )
177+ 
178+                 node .func  =  ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
179+                 base_args  =  [
180+                     ast .Name (id = function_name , ctx = ast .Load ()),
181+                     ast .Constant (value = module_path ),
182+                     ast .Constant (value = test_class_name  or  None ),
183+                     ast .Constant (value = node_name ),
184+                     ast .Constant (value = qualified_name ),
185+                     ast .Constant (value = index ),
186+                     codeflash_loop_index ,
187+                 ]
188+                 # Extend with BEHAVIOR extras if needed 
189+                 if  mode  ==  TestingMode .BEHAVIOR :
190+                     base_args  +=  [codeflash_cur , codeflash_con ]
191+                 # Extend with call args (performance) or starred bound args (behavior) 
192+                 if  mode  ==  TestingMode .PERFORMANCE :
193+                     base_args  +=  call_node .args 
194+                 else :
195+                     base_args .append (
196+                         ast .Starred (
197+                             value = ast .Attribute (
198+                                 value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
199+                                 attr = "args" ,
200+                                 ctx = ast .Load (),
201+                             ),
202+                             ctx = ast .Load (),
203+                         )
204+                     )
205+                 node .args  =  base_args 
206+                 # Prepare keywords 
207+                 if  mode  ==  TestingMode .BEHAVIOR :
208+                     node .keywords  =  [
209+                         ast .keyword (
210+                             value = ast .Attribute (
211+                                 value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
212+                                 attr = "kwargs" ,
213+                                 ctx = ast .Load (),
214+                             )
215+                         )
216+                     ]
217+                 else :
218+                     node .keywords  =  call_node .keywords 
219+ 
220+                 return_statement  =  (
221+                     [bind_call , apply_defaults , test_node ] if  mode  ==  TestingMode .BEHAVIOR  else  [test_node ]
222+                 )
223+                 break 
224+             if  isinstance (node_func , ast .Attribute ):
225+                 function_to_test  =  node_func .attr 
226+                 if  function_to_test  ==  fn_obj .function_name :
227+                     if  fn_obj .is_async :
97228                        return  [test_node ]
98229
99230                    # Create the signature binding statements 
231+ 
232+                     # Unparse only once 
233+                     function_name_expr  =  ast .parse (ast .unparse (node_func ), mode = "eval" ).body 
234+ 
235+                     inspect_name  =  ast .Name (id = "inspect" , ctx = ast .Load ())
100236                    bind_call  =  ast .Assign (
101237                        targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
102238                        value = ast .Call (
103239                            func = ast .Attribute (
104240                                value = ast .Call (
105-                                     func = ast .Attribute (
106-                                         value = ast .Name (id = "inspect" , ctx = ast .Load ()), attr = "signature" , ctx = ast .Load ()
107-                                     ),
108-                                     args = [ast .Name (id = function_name , ctx = ast .Load ())],
241+                                     func = ast .Attribute (value = inspect_name , attr = "signature" , ctx = ast .Load ()),
242+                                     args = [function_name_expr ],
109243                                    keywords = [],
110244                                ),
111245                                attr = "bind" ,
@@ -133,36 +267,33 @@ def find_and_update_line_node(
133267                    )
134268
135269                    node .func  =  ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
136-                     node . args  =  [
137-                         ast . Name ( id = function_name ,  ctx = ast . Load ()) ,
138-                         ast .Constant (value = self . module_path ),
270+                     base_args  =  [
271+                         function_name_expr ,
272+                         ast .Constant (value = module_path ),
139273                        ast .Constant (value = test_class_name  or  None ),
140274                        ast .Constant (value = node_name ),
141-                         ast .Constant (value = self . function_object . qualified_name ),
275+                         ast .Constant (value = qualified_name ),
142276                        ast .Constant (value = index ),
143-                         ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
144-                         * (
145-                             [ast .Name (id = "codeflash_cur" , ctx = ast .Load ()), ast .Name (id = "codeflash_con" , ctx = ast .Load ())]
146-                             if  self .mode  ==  TestingMode .BEHAVIOR 
147-                             else  []
148-                         ),
149-                         * (
150-                             call_node .args 
151-                             if  self .mode  ==  TestingMode .PERFORMANCE 
152-                             else  [
153-                                 ast .Starred (
154-                                     value = ast .Attribute (
155-                                         value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
156-                                         attr = "args" ,
157-                                         ctx = ast .Load (),
158-                                     ),
159-                                     ctx = ast .Load (),
160-                                 )
161-                             ]
162-                         ),
277+                         codeflash_loop_index ,
163278                    ]
164-                     node .keywords  =  (
165-                         [
279+                     if  mode  ==  TestingMode .BEHAVIOR :
280+                         base_args  +=  [codeflash_cur , codeflash_con ]
281+                     if  mode  ==  TestingMode .PERFORMANCE :
282+                         base_args  +=  call_node .args 
283+                     else :
284+                         base_args .append (
285+                             ast .Starred (
286+                                 value = ast .Attribute (
287+                                     value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
288+                                     attr = "args" ,
289+                                     ctx = ast .Load (),
290+                                 ),
291+                                 ctx = ast .Load (),
292+                             )
293+                         )
294+                     node .args  =  base_args 
295+                     if  mode  ==  TestingMode .BEHAVIOR :
296+                         node .keywords  =  [
166297                            ast .keyword (
167298                                value = ast .Attribute (
168299                                    value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
@@ -171,112 +302,14 @@ def find_and_update_line_node(
171302                                )
172303                            )
173304                        ]
174-                         if  self .mode  ==  TestingMode .BEHAVIOR 
175-                         else  call_node .keywords 
176-                     )
305+                     else :
306+                         node .keywords  =  call_node .keywords 
177307
178308                    # Return the signature binding statements along with the test_node 
179309                    return_statement  =  (
180-                         [bind_call , apply_defaults , test_node ] if  self . mode  ==  TestingMode .BEHAVIOR  else  [test_node ]
310+                         [bind_call , apply_defaults , test_node ] if  mode  ==  TestingMode .BEHAVIOR  else  [test_node ]
181311                    )
182312                    break 
183-                 if  isinstance (node .func , ast .Attribute ):
184-                     function_to_test  =  node .func .attr 
185-                     if  function_to_test  ==  self .function_object .function_name :
186-                         if  self .function_object .is_async :
187-                             return  [test_node ]
188- 
189-                         function_name  =  ast .unparse (node .func )
190- 
191-                         # Create the signature binding statements 
192-                         bind_call  =  ast .Assign (
193-                             targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
194-                             value = ast .Call (
195-                                 func = ast .Attribute (
196-                                     value = ast .Call (
197-                                         func = ast .Attribute (
198-                                             value = ast .Name (id = "inspect" , ctx = ast .Load ()),
199-                                             attr = "signature" ,
200-                                             ctx = ast .Load (),
201-                                         ),
202-                                         args = [ast .parse (function_name , mode = "eval" ).body ],
203-                                         keywords = [],
204-                                     ),
205-                                     attr = "bind" ,
206-                                     ctx = ast .Load (),
207-                                 ),
208-                                 args = all_args .args ,
209-                                 keywords = all_args .keywords ,
210-                             ),
211-                             lineno = test_node .lineno ,
212-                             col_offset = test_node .col_offset ,
213-                         )
214- 
215-                         apply_defaults  =  ast .Expr (
216-                             value = ast .Call (
217-                                 func = ast .Attribute (
218-                                     value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
219-                                     attr = "apply_defaults" ,
220-                                     ctx = ast .Load (),
221-                                 ),
222-                                 args = [],
223-                                 keywords = [],
224-                             ),
225-                             lineno = test_node .lineno  +  1 ,
226-                             col_offset = test_node .col_offset ,
227-                         )
228- 
229-                         node .func  =  ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
230-                         node .args  =  [
231-                             ast .parse (function_name , mode = "eval" ).body ,
232-                             ast .Constant (value = self .module_path ),
233-                             ast .Constant (value = test_class_name  or  None ),
234-                             ast .Constant (value = node_name ),
235-                             ast .Constant (value = self .function_object .qualified_name ),
236-                             ast .Constant (value = index ),
237-                             ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
238-                             * (
239-                                 [
240-                                     ast .Name (id = "codeflash_cur" , ctx = ast .Load ()),
241-                                     ast .Name (id = "codeflash_con" , ctx = ast .Load ()),
242-                                 ]
243-                                 if  self .mode  ==  TestingMode .BEHAVIOR 
244-                                 else  []
245-                             ),
246-                             * (
247-                                 call_node .args 
248-                                 if  self .mode  ==  TestingMode .PERFORMANCE 
249-                                 else  [
250-                                     ast .Starred (
251-                                         value = ast .Attribute (
252-                                             value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
253-                                             attr = "args" ,
254-                                             ctx = ast .Load (),
255-                                         ),
256-                                         ctx = ast .Load (),
257-                                     )
258-                                 ]
259-                             ),
260-                         ]
261-                         node .keywords  =  (
262-                             [
263-                                 ast .keyword (
264-                                     value = ast .Attribute (
265-                                         value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
266-                                         attr = "kwargs" ,
267-                                         ctx = ast .Load (),
268-                                     )
269-                                 )
270-                             ]
271-                             if  self .mode  ==  TestingMode .BEHAVIOR 
272-                             else  call_node .keywords 
273-                         )
274- 
275-                         # Return the signature binding statements along with the test_node 
276-                         return_statement  =  (
277-                             [bind_call , apply_defaults , test_node ] if  self .mode  ==  TestingMode .BEHAVIOR  else  [test_node ]
278-                         )
279-                         break 
280313
281314        if  call_node  is  None :
282315            return  None 
0 commit comments