@@ -83,6 +83,44 @@ def find_and_update_line_node(
8383 if self .function_object .is_async :
8484 return [test_node ]
8585
86+ # Create the signature binding statements
87+ bind_call = ast .Assign (
88+ targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
89+ value = ast .Call (
90+ func = ast .Attribute (
91+ value = ast .Call (
92+ func = ast .Attribute (
93+ value = ast .Name (id = "inspect" , ctx = ast .Load ()),
94+ attr = "signature" ,
95+ ctx = ast .Load ()
96+ ),
97+ args = [ast .Name (id = function_name , ctx = ast .Load ())],
98+ keywords = []
99+ ),
100+ attr = "bind" ,
101+ ctx = ast .Load ()
102+ ),
103+ args = [ast .Starred (value = ast .Attribute (value = call_node , attr = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
104+ keywords = [ast .keyword (arg = None , value = ast .Attribute (value = call_node , attr = "keywords" , ctx = ast .Load ()))]
105+ ),
106+ lineno = test_node .lineno if hasattr (test_node , 'lineno' ) else 1 ,
107+ col_offset = test_node .col_offset if hasattr (test_node , 'col_offset' ) else 0
108+ )
109+
110+ apply_defaults = ast .Expr (
111+ value = ast .Call (
112+ func = ast .Attribute (
113+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
114+ attr = "apply_defaults" ,
115+ ctx = ast .Load ()
116+ ),
117+ args = [],
118+ keywords = []
119+ ),
120+ lineno = test_node .lineno + 1 ,
121+ col_offset = test_node .col_offset
122+ )
123+
86124 node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
87125 node .args = [
88126 ast .Name (id = function_name , ctx = ast .Load ()),
@@ -100,14 +138,55 @@ def find_and_update_line_node(
100138 * call_node .args ,
101139 ]
102140 node .keywords = call_node .keywords
103- break
141+
142+ # Return the signature binding statements along with the test_node
143+ return [bind_call , apply_defaults , test_node ]
104144 if isinstance (node .func , ast .Attribute ):
105145 function_to_test = node .func .attr
106146 if function_to_test == self .function_object .function_name :
107147 if self .function_object .is_async :
108148 return [test_node ]
109149
110150 function_name = ast .unparse (node .func )
151+
152+ # Create the signature binding statements
153+ bind_call = ast .Assign (
154+ targets = [ast .Name (id = "_call__bound__arguments" , ctx = ast .Store ())],
155+ value = ast .Call (
156+ func = ast .Attribute (
157+ value = ast .Call (
158+ func = ast .Attribute (
159+ value = ast .Name (id = "inspect" , ctx = ast .Load ()),
160+ attr = "signature" ,
161+ ctx = ast .Load ()
162+ ),
163+ args = function_name ,
164+ keywords = []
165+ ),
166+ attr = "bind" ,
167+ ctx = ast .Load ()
168+ ),
169+ args = call_node .args ,
170+ keywords = call_node .keywords
171+ ),
172+ lineno = test_node .lineno ,
173+ col_offset = test_node .col_offset
174+ )
175+
176+ apply_defaults = ast .Expr (
177+ value = ast .Call (
178+ func = ast .Attribute (
179+ value = ast .Name (id = "_call__bound__arguments" , ctx = ast .Load ()),
180+ attr = "apply_defaults" ,
181+ ctx = ast .Load ()
182+ ),
183+ args = [],
184+ keywords = []
185+ ),
186+ lineno = test_node .lineno + 1 ,
187+ col_offset = test_node .col_offset
188+ )
189+
111190 node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
112191 node .args = [
113192 ast .Name (id = function_name , ctx = ast .Load ()),
@@ -128,11 +207,14 @@ def find_and_update_line_node(
128207 * call_node .args ,
129208 ]
130209 node .keywords = call_node .keywords
210+
211+ # Return the signature binding statements along with the test_node
212+ return_statement = [bind_call , apply_defaults , test_node ]
131213 break
132214
133215 if call_node is None :
134216 return None
135- return [ test_node ]
217+ return return_statement
136218
137219 def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
138220 # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
@@ -590,6 +672,7 @@ def inject_profiling_into_existing_test(
590672 ast .Import (names = [ast .alias (name = "time" )]),
591673 ast .Import (names = [ast .alias (name = "gc" )]),
592674 ast .Import (names = [ast .alias (name = "os" )]),
675+ ast .Import (names = [ast .alias (name = "inspect" )]),
593676 ]
594677 if mode == TestingMode .BEHAVIOR :
595678 new_imports .extend (
0 commit comments