22
33import ast
44import platform
5+ from dataclasses import dataclass
56from pathlib import Path
67from typing import TYPE_CHECKING
78
2021 from codeflash .models .models import CodePosition
2122
2223
24+ @dataclass (frozen = True )
25+ class FunctionCallNodeArguments :
26+ args : list [ast .expr ]
27+ keywords : list [ast .keyword ]
28+
29+
30+ def get_call_arguments (call_node : ast .Call ) -> FunctionCallNodeArguments :
31+ return FunctionCallNodeArguments (call_node .args , call_node .keywords )
32+
33+
2334def node_in_call_position (node : ast .AST , call_positions : list [CodePosition ]) -> bool :
2435 if isinstance (node , ast .Call ) and hasattr (node , "lineno" ) and hasattr (node , "col_offset" ):
2536 for pos in call_positions :
@@ -73,10 +84,12 @@ def __init__(
7384 def find_and_update_line_node (
7485 self , test_node : ast .stmt , node_name : str , index : str , test_class_name : str | None = None
7586 ) -> Iterable [ast .stmt ] | None :
87+ return_statement = [test_node ]
7688 call_node = None
7789 for node in ast .walk (test_node ):
7890 if isinstance (node , ast .Call ) and node_in_call_position (node , self .call_positions ):
7991 call_node = node
92+ all_args = get_call_arguments (call_node )
8093 if isinstance (node .func , ast .Name ):
8194 function_name = node .func .id
8295
@@ -90,35 +103,33 @@ def find_and_update_line_node(
90103 func = ast .Attribute (
91104 value = ast .Call (
92105 func = ast .Attribute (
93- value = ast .Name (id = "inspect" , ctx = ast .Load ()),
94- attr = "signature" ,
95- ctx = ast .Load ()
106+ value = ast .Name (id = "inspect" , ctx = ast .Load ()), attr = "signature" , ctx = ast .Load ()
96107 ),
97108 args = [ast .Name (id = function_name , ctx = ast .Load ())],
98- keywords = []
109+ keywords = [],
99110 ),
100111 attr = "bind" ,
101- ctx = ast .Load ()
112+ ctx = ast .Load (),
102113 ),
103- args = [ ast . Starred ( value = ast . Attribute ( value = call_node , attr = " args" , ctx = ast . Load ()), ctx = ast . Load ())] ,
104- keywords = [ ast . keyword ( arg = None , value = ast . Attribute ( value = call_node , attr = " keywords" , ctx = ast . Load ()))]
114+ args = all_args . args ,
115+ keywords = all_args . keywords ,
105116 ),
106- lineno = test_node .lineno if hasattr ( test_node , 'lineno' ) else 1 ,
107- col_offset = test_node .col_offset if hasattr ( test_node , 'col_offset' ) else 0
117+ lineno = test_node .lineno ,
118+ col_offset = test_node .col_offset ,
108119 )
109120
110121 apply_defaults = ast .Expr (
111122 value = ast .Call (
112123 func = ast .Attribute (
113124 value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
114125 attr = "apply_defaults" ,
115- ctx = ast .Load ()
126+ ctx = ast .Load (),
116127 ),
117128 args = [],
118- keywords = []
129+ keywords = [],
119130 ),
120131 lineno = test_node .lineno + 1 ,
121- col_offset = test_node .col_offset
132+ col_offset = test_node .col_offset ,
122133 )
123134
124135 node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
@@ -135,12 +146,40 @@ def find_and_update_line_node(
135146 if self .mode == TestingMode .BEHAVIOR
136147 else []
137148 ),
138- * call_node .args ,
149+ * (
150+ call_node .args
151+ if self .mode == TestingMode .PERFORMANCE
152+ else [
153+ ast .Starred (
154+ value = ast .Attribute (
155+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
156+ attr = "args" ,
157+ ctx = ast .Load (),
158+ ),
159+ ctx = ast .Load (),
160+ )
161+ ]
162+ ),
139163 ]
140- node .keywords = call_node .keywords
164+ node .keywords = (
165+ [
166+ ast .keyword (
167+ value = ast .Attribute (
168+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
169+ attr = "kwargs" ,
170+ ctx = ast .Load (),
171+ )
172+ )
173+ ]
174+ if self .mode == TestingMode .BEHAVIOR
175+ else call_node .keywords
176+ )
141177
142178 # Return the signature binding statements along with the test_node
143- return [bind_call , apply_defaults , test_node ]
179+ return_statement = (
180+ [bind_call , apply_defaults , test_node ] if self .mode == TestingMode .BEHAVIOR else [test_node ]
181+ )
182+ break
144183 if isinstance (node .func , ast .Attribute ):
145184 function_to_test = node .func .attr
146185 if function_to_test == self .function_object .function_name :
@@ -158,38 +197,38 @@ def find_and_update_line_node(
158197 func = ast .Attribute (
159198 value = ast .Name (id = "inspect" , ctx = ast .Load ()),
160199 attr = "signature" ,
161- ctx = ast .Load ()
200+ ctx = ast .Load (),
162201 ),
163- args = function_name ,
164- keywords = []
202+ args = [ ast . parse ( function_name , mode = "eval" ). body ] ,
203+ keywords = [],
165204 ),
166205 attr = "bind" ,
167- ctx = ast .Load ()
206+ ctx = ast .Load (),
168207 ),
169- args = call_node .args ,
170- keywords = call_node .keywords
208+ args = all_args .args ,
209+ keywords = all_args .keywords ,
171210 ),
172211 lineno = test_node .lineno ,
173- col_offset = test_node .col_offset
212+ col_offset = test_node .col_offset ,
174213 )
175214
176215 apply_defaults = ast .Expr (
177216 value = ast .Call (
178217 func = ast .Attribute (
179218 value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
180219 attr = "apply_defaults" ,
181- ctx = ast .Load ()
220+ ctx = ast .Load (),
182221 ),
183222 args = [],
184- keywords = []
223+ keywords = [],
185224 ),
186225 lineno = test_node .lineno + 1 ,
187- col_offset = test_node .col_offset
226+ col_offset = test_node .col_offset ,
188227 )
189228
190229 node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
191230 node .args = [
192- ast .Name ( id = function_name , ctx = ast . Load ()) ,
231+ ast .parse ( function_name , mode = "eval" ). body ,
193232 ast .Constant (value = self .module_path ),
194233 ast .Constant (value = test_class_name or None ),
195234 ast .Constant (value = node_name ),
@@ -204,12 +243,39 @@ def find_and_update_line_node(
204243 if self .mode == TestingMode .BEHAVIOR
205244 else []
206245 ),
207- * call_node .args ,
246+ * (
247+ call_node .args
248+ if self .mode == TestingMode .PERFORMANCE
249+ else [
250+ ast .Starred (
251+ value = ast .Attribute (
252+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
253+ attr = "args" ,
254+ ctx = ast .Load (),
255+ ),
256+ ctx = ast .Load (),
257+ )
258+ ]
259+ ),
208260 ]
209- node .keywords = call_node .keywords
261+ node .keywords = (
262+ [
263+ ast .keyword (
264+ value = ast .Attribute (
265+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
266+ attr = "kwargs" ,
267+ ctx = ast .Load (),
268+ )
269+ )
270+ ]
271+ if self .mode == TestingMode .BEHAVIOR
272+ else call_node .keywords
273+ )
210274
211275 # Return the signature binding statements along with the test_node
212- return_statement = [bind_call , apply_defaults , test_node ]
276+ return_statement = (
277+ [bind_call , apply_defaults , test_node ] if self .mode == TestingMode .BEHAVIOR else [test_node ]
278+ )
213279 break
214280
215281 if call_node is None :
0 commit comments