@@ -73,6 +73,8 @@ def find_and_update_line_node(
7373 self , test_node : ast .stmt , node_name : str , index : str , test_class_name : str | None = None
7474 ) -> Iterable [ast .stmt ] | None :
7575 call_node = None
76+ await_node = None
77+
7678 for node in ast .walk (test_node ):
7779 if isinstance (node , ast .Call ) and node_in_call_position (node , self .call_positions ):
7880 call_node = node
@@ -121,6 +123,60 @@ def find_and_update_line_node(
121123 ]
122124 node .keywords = call_node .keywords
123125 break
126+
127+ # Check for awaited function calls
128+ elif isinstance (node , ast .Await ) and isinstance (node .value , ast .Call ) and node_in_call_position (node .value , self .call_positions ):
129+ call_node = node .value
130+ await_node = node
131+ if isinstance (call_node .func , ast .Name ):
132+ function_name = call_node .func .id
133+ call_node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
134+ call_node .args = [
135+ ast .Name (id = function_name , ctx = ast .Load ()),
136+ ast .Constant (value = self .module_path ),
137+ ast .Constant (value = test_class_name or None ),
138+ ast .Constant (value = node_name ),
139+ ast .Constant (value = self .function_object .qualified_name ),
140+ ast .Constant (value = index ),
141+ ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
142+ * (
143+ [ast .Name (id = "codeflash_cur" , ctx = ast .Load ()), ast .Name (id = "codeflash_con" , ctx = ast .Load ())]
144+ if self .mode == TestingMode .BEHAVIOR
145+ else []
146+ ),
147+ * call_node .args ,
148+ ]
149+ call_node .keywords = call_node .keywords
150+ # Keep the await wrapper around the modified call
151+ await_node .value = call_node
152+ break
153+ if isinstance (call_node .func , ast .Attribute ):
154+ function_to_test = call_node .func .attr
155+ if function_to_test == self .function_object .function_name :
156+ function_name = ast .unparse (call_node .func )
157+ call_node .func = ast .Name (id = "codeflash_wrap" , ctx = ast .Load ())
158+ call_node .args = [
159+ ast .Name (id = function_name , ctx = ast .Load ()),
160+ ast .Constant (value = self .module_path ),
161+ ast .Constant (value = test_class_name or None ),
162+ ast .Constant (value = node_name ),
163+ ast .Constant (value = self .function_object .qualified_name ),
164+ ast .Constant (value = index ),
165+ ast .Name (id = "codeflash_loop_index" , ctx = ast .Load ()),
166+ * (
167+ [
168+ ast .Name (id = "codeflash_cur" , ctx = ast .Load ()),
169+ ast .Name (id = "codeflash_con" , ctx = ast .Load ()),
170+ ]
171+ if self .mode == TestingMode .BEHAVIOR
172+ else []
173+ ),
174+ * call_node .args ,
175+ ]
176+ call_node .keywords = call_node .keywords
177+ # Keep the await wrapper around the modified call
178+ await_node .value = call_node
179+ break
124180
125181 if call_node is None :
126182 return None
@@ -131,9 +187,35 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
131187 for inner_node in ast .walk (node ):
132188 if isinstance (inner_node , ast .FunctionDef ):
133189 self .visit_FunctionDef (inner_node , node .name )
190+ elif isinstance (inner_node , ast .AsyncFunctionDef ):
191+ self .visit_AsyncFunctionDef (inner_node , node .name )
134192
135193 return node
136194
195+ def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef , test_class_name : str | None = None ) -> ast .AsyncFunctionDef :
196+ """Handle async function definitions by converting to sync and back."""
197+ # Convert to sync FunctionDef, process it, then convert back
198+ sync_node = ast .FunctionDef (
199+ name = node .name ,
200+ args = node .args ,
201+ body = node .body ,
202+ decorator_list = node .decorator_list ,
203+ returns = node .returns ,
204+ lineno = node .lineno ,
205+ col_offset = node .col_offset if hasattr (node , 'col_offset' ) else 0
206+ )
207+ processed_sync = self .visit_FunctionDef (sync_node , test_class_name )
208+ # Convert back to AsyncFunctionDef
209+ return ast .AsyncFunctionDef (
210+ name = processed_sync .name ,
211+ args = processed_sync .args ,
212+ body = processed_sync .body ,
213+ decorator_list = processed_sync .decorator_list ,
214+ returns = processed_sync .returns ,
215+ lineno = processed_sync .lineno ,
216+ col_offset = processed_sync .col_offset if hasattr (processed_sync , 'col_offset' ) else 0
217+ )
218+
137219 def visit_FunctionDef (self , node : ast .FunctionDef , test_class_name : str | None = None ) -> ast .FunctionDef :
138220 if node .name .startswith ("test_" ):
139221 did_update = False
0 commit comments