@@ -32,23 +32,28 @@ def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
3232
3333
3434def node_in_call_position (node : ast .AST , call_positions : list [CodePosition ]) -> bool :
35- if isinstance (node , ast .Call ) and hasattr (node , "lineno" ) and hasattr (node , "col_offset" ):
36- for pos in call_positions :
37- if (
38- pos .line_no is not None
39- and node .end_lineno is not None
40- and node .lineno <= pos .line_no <= node .end_lineno
41- ):
42- if pos .line_no == node .lineno and node .col_offset <= pos .col_no :
43- return True
44- if (
45- pos .line_no == node .end_lineno
46- and node .end_col_offset is not None
47- and node .end_col_offset >= pos .col_no
48- ):
49- return True
50- if node .lineno < pos .line_no < node .end_lineno :
51- return True
35+ # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty.
36+ # Small optimizations for tight loop:
37+ if isinstance (node , ast .Call ):
38+ node_lineno = getattr (node , "lineno" , None )
39+ node_col_offset = getattr (node , "col_offset" , None )
40+ node_end_lineno = getattr (node , "end_lineno" , None )
41+ node_end_col_offset = getattr (node , "end_col_offset" , None )
42+ if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None :
43+ # Faster loop: reduce attribute lookups, use local variables for conditionals.
44+ for pos in call_positions :
45+ pos_line = pos .line_no
46+ if pos_line is not None and node_lineno <= pos_line <= node_end_lineno :
47+ if pos_line == node_lineno and node_col_offset <= pos .col_no :
48+ return True
49+ if (
50+ pos_line == node_end_lineno
51+ and node_end_col_offset is not None
52+ and node_end_col_offset >= pos .col_no
53+ ):
54+ return True
55+ if node_lineno < pos_line < node_end_lineno :
56+ return True
5257 return False
5358
5459
@@ -84,28 +89,157 @@ def __init__(
8489 def find_and_update_line_node (
8590 self , test_node : ast .stmt , node_name : str , index : str , test_class_name : str | None = None
8691 ) -> Iterable [ast .stmt ] | None :
92+ # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call,
93+ # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
94+
95+ # Helper for manual walk
96+ def iter_ast_calls (node ):
97+ # Generator to yield each ast.Call in test_node, preserves node identity
98+ stack = [node ]
99+ while stack :
100+ n = stack .pop ()
101+ if isinstance (n , ast .Call ):
102+ yield n
103+ # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node),
104+ # do a specialized BFS with only the necessary attributes
105+ for field , value in ast .iter_fields (n ):
106+ if isinstance (value , list ):
107+ for item in reversed (value ):
108+ if isinstance (item , ast .AST ):
109+ stack .append (item )
110+ elif isinstance (value , ast .AST ):
111+ stack .append (value )
112+
113+ # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead
87114 return_statement = [test_node ]
88115 call_node = None
89- for node in ast .walk (test_node ):
90- if isinstance (node , ast .Call ) and node_in_call_position (node , self .call_positions ):
91- call_node = node
92- all_args = get_call_arguments (call_node )
93- if isinstance (node .func , ast .Name ):
94- function_name = node .func .id
95-
96- if self .function_object .is_async :
116+
117+ # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals
118+ fn_obj = self .function_object
119+ module_path = self .module_path
120+ mode = self .mode
121+ qualified_name = fn_obj .qualified_name
122+
123+ # Use locals for all 'current' values, only look up class/function/constant AST object once.
124+ codeflash_loop_index = ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ())
125+ codeflash_cur = ast .Name (id = "codeflash_cur" , ctx = ast .Load ())
126+ codeflash_con = ast .Name (id = "codeflash_con" , ctx = ast .Load ())
127+
128+ for node in iter_ast_calls (test_node ):
129+ if not node_in_call_position (node , self .call_positions ):
130+ continue
131+
132+ call_node = node
133+ all_args = get_call_arguments (call_node )
134+ # Two possible call types: Name and Attribute
135+ node_func = node .func
136+
137+ if isinstance (node_func , ast .Name ):
138+ function_name = node_func .id
139+
140+ if fn_obj .is_async :
141+ return [test_node ]
142+
143+ # Build once, reuse objects.
144+ inspect_name = ast .Name (id = "inspect" , ctx = ast .Load ())
145+ bind_call = ast .Assign (
146+ targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
147+ value = ast .Call (
148+ func = ast .Attribute (
149+ value = ast .Call (
150+ func = ast .Attribute (value = inspect_name , attr = "signature" , ctx = ast .Load ()),
151+ args = [ast .Name (id = function_name , ctx = ast .Load ())],
152+ keywords = [],
153+ ),
154+ attr = "bind" ,
155+ ctx = ast .Load (),
156+ ),
157+ args = all_args .args ,
158+ keywords = all_args .keywords ,
159+ ),
160+ lineno = test_node .lineno ,
161+ col_offset = test_node .col_offset ,
162+ )
163+
164+ apply_defaults = ast .Expr (
165+ value = ast .Call (
166+ func = ast .Attribute (
167+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
168+ attr = "apply_defaults" ,
169+ ctx = ast .Load (),
170+ ),
171+ args = [],
172+ keywords = [],
173+ ),
174+ lineno = test_node .lineno + 1 ,
175+ col_offset = test_node .col_offset ,
176+ )
177+
178+ node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
179+ base_args = [
180+ ast .Name (id = function_name , ctx = ast .Load ()),
181+ ast .Constant (value = module_path ),
182+ ast .Constant (value = test_class_name or None ),
183+ ast .Constant (value = node_name ),
184+ ast .Constant (value = qualified_name ),
185+ ast .Constant (value = index ),
186+ codeflash_loop_index ,
187+ ]
188+ # Extend with BEHAVIOR extras if needed
189+ if mode == TestingMode .BEHAVIOR :
190+ base_args += [codeflash_cur , codeflash_con ]
191+ # Extend with call args (performance) or starred bound args (behavior)
192+ if mode == TestingMode .PERFORMANCE :
193+ base_args += call_node .args
194+ else :
195+ base_args .append (
196+ ast .Starred (
197+ value = ast .Attribute (
198+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
199+ attr = "args" ,
200+ ctx = ast .Load (),
201+ ),
202+ ctx = ast .Load (),
203+ )
204+ )
205+ node .args = base_args
206+ # Prepare keywords
207+ if mode == TestingMode .BEHAVIOR :
208+ node .keywords = [
209+ ast .keyword (
210+ value = ast .Attribute (
211+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
212+ attr = "kwargs" ,
213+ ctx = ast .Load (),
214+ )
215+ )
216+ ]
217+ else :
218+ node .keywords = call_node .keywords
219+
220+ return_statement = (
221+ [bind_call , apply_defaults , test_node ] if mode == TestingMode .BEHAVIOR else [test_node ]
222+ )
223+ break
224+ if isinstance (node_func , ast .Attribute ):
225+ function_to_test = node_func .attr
226+ if function_to_test == fn_obj .function_name :
227+ if fn_obj .is_async :
97228 return [test_node ]
98229
99230 # Create the signature binding statements
231+
232+ # Unparse only once
233+ function_name_expr = ast .parse (ast .unparse (node_func ), mode = "eval" ).body
234+
235+ inspect_name = ast .Name (id = "inspect" , ctx = ast .Load ())
100236 bind_call = ast .Assign (
101237 targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
102238 value = ast .Call (
103239 func = ast .Attribute (
104240 value = ast .Call (
105- func = ast .Attribute (
106- value = ast .Name (id = "inspect" , ctx = ast .Load ()), attr = "signature" , ctx = ast .Load ()
107- ),
108- args = [ast .Name (id = function_name , ctx = ast .Load ())],
241+ func = ast .Attribute (value = inspect_name , attr = "signature" , ctx = ast .Load ()),
242+ args = [function_name_expr ],
109243 keywords = [],
110244 ),
111245 attr = "bind" ,
@@ -133,36 +267,33 @@ def find_and_update_line_node(
133267 )
134268
135269 node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
136- node . args = [
137- ast . Name ( id = function_name , ctx = ast . Load ()) ,
138- ast .Constant (value = self . module_path ),
270+ base_args = [
271+ function_name_expr ,
272+ ast .Constant (value = module_path ),
139273 ast .Constant (value = test_class_name or None ),
140274 ast .Constant (value = node_name ),
141- ast .Constant (value = self . function_object . qualified_name ),
275+ ast .Constant (value = qualified_name ),
142276 ast .Constant (value = index ),
143- ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
144- * (
145- [ast .Name (id = "codeflash_cur" , ctx = ast .Load ()), ast .Name (id = "codeflash_con" , ctx = ast .Load ())]
146- if self .mode == TestingMode .BEHAVIOR
147- else []
148- ),
149- * (
150- call_node .args
151- if self .mode == TestingMode .PERFORMANCE
152- else [
153- ast .Starred (
154- value = ast .Attribute (
155- value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
156- attr = "args" ,
157- ctx = ast .Load (),
158- ),
159- ctx = ast .Load (),
160- )
161- ]
162- ),
277+ codeflash_loop_index ,
163278 ]
164- node .keywords = (
165- [
279+ if mode == TestingMode .BEHAVIOR :
280+ base_args += [codeflash_cur , codeflash_con ]
281+ if mode == TestingMode .PERFORMANCE :
282+ base_args += call_node .args
283+ else :
284+ base_args .append (
285+ ast .Starred (
286+ value = ast .Attribute (
287+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
288+ attr = "args" ,
289+ ctx = ast .Load (),
290+ ),
291+ ctx = ast .Load (),
292+ )
293+ )
294+ node .args = base_args
295+ if mode == TestingMode .BEHAVIOR :
296+ node .keywords = [
166297 ast .keyword (
167298 value = ast .Attribute (
168299 value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
@@ -171,112 +302,14 @@ def find_and_update_line_node(
171302 )
172303 )
173304 ]
174- if self .mode == TestingMode .BEHAVIOR
175- else call_node .keywords
176- )
305+ else :
306+ node .keywords = call_node .keywords
177307
178308 # Return the signature binding statements along with the test_node
179309 return_statement = (
180- [bind_call , apply_defaults , test_node ] if self . mode == TestingMode .BEHAVIOR else [test_node ]
310+ [bind_call , apply_defaults , test_node ] if mode == TestingMode .BEHAVIOR else [test_node ]
181311 )
182312 break
183- if isinstance (node .func , ast .Attribute ):
184- function_to_test = node .func .attr
185- if function_to_test == self .function_object .function_name :
186- if self .function_object .is_async :
187- return [test_node ]
188-
189- function_name = ast .unparse (node .func )
190-
191- # Create the signature binding statements
192- bind_call = ast .Assign (
193- targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
194- value = ast .Call (
195- func = ast .Attribute (
196- value = ast .Call (
197- func = ast .Attribute (
198- value = ast .Name (id = "inspect" , ctx = ast .Load ()),
199- attr = "signature" ,
200- ctx = ast .Load (),
201- ),
202- args = [ast .parse (function_name , mode = "eval" ).body ],
203- keywords = [],
204- ),
205- attr = "bind" ,
206- ctx = ast .Load (),
207- ),
208- args = all_args .args ,
209- keywords = all_args .keywords ,
210- ),
211- lineno = test_node .lineno ,
212- col_offset = test_node .col_offset ,
213- )
214-
215- apply_defaults = ast .Expr (
216- value = ast .Call (
217- func = ast .Attribute (
218- value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
219- attr = "apply_defaults" ,
220- ctx = ast .Load (),
221- ),
222- args = [],
223- keywords = [],
224- ),
225- lineno = test_node .lineno + 1 ,
226- col_offset = test_node .col_offset ,
227- )
228-
229- node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
230- node .args = [
231- ast .parse (function_name , mode = "eval" ).body ,
232- ast .Constant (value = self .module_path ),
233- ast .Constant (value = test_class_name or None ),
234- ast .Constant (value = node_name ),
235- ast .Constant (value = self .function_object .qualified_name ),
236- ast .Constant (value = index ),
237- ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
238- * (
239- [
240- ast .Name (id = "codeflash_cur" , ctx = ast .Load ()),
241- ast .Name (id = "codeflash_con" , ctx = ast .Load ()),
242- ]
243- if self .mode == TestingMode .BEHAVIOR
244- else []
245- ),
246- * (
247- call_node .args
248- if self .mode == TestingMode .PERFORMANCE
249- else [
250- ast .Starred (
251- value = ast .Attribute (
252- value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
253- attr = "args" ,
254- ctx = ast .Load (),
255- ),
256- ctx = ast .Load (),
257- )
258- ]
259- ),
260- ]
261- node .keywords = (
262- [
263- ast .keyword (
264- value = ast .Attribute (
265- value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
266- attr = "kwargs" ,
267- ctx = ast .Load (),
268- )
269- )
270- ]
271- if self .mode == TestingMode .BEHAVIOR
272- else call_node .keywords
273- )
274-
275- # Return the signature binding statements along with the test_node
276- return_statement = (
277- [bind_call , apply_defaults , test_node ] if self .mode == TestingMode .BEHAVIOR else [test_node ]
278- )
279- break
280313
281314 if call_node is None :
282315 return None
0 commit comments