@@ -74,7 +74,7 @@ def find_and_update_line_node(
7474 ) -> Iterable [ast .stmt ] | None :
7575 call_node = None
7676 await_node = None
77-
77+
7878 for node in ast .walk (test_node ):
7979 if isinstance (node , ast .Call ) and node_in_call_position (node , self .call_positions ):
8080 call_node = node
@@ -123,9 +123,13 @@ def find_and_update_line_node(
123123 ]
124124 node .keywords = call_node .keywords
125125 break
126-
126+
127127 # Check for awaited function calls
128- elif isinstance (node , ast .Await ) and isinstance (node .value , ast .Call ) and node_in_call_position (node .value , self .call_positions ):
128+ elif (
129+ isinstance (node , ast .Await )
130+ and isinstance (node .value , ast .Call )
131+ and node_in_call_position (node .value , self .call_positions )
132+ ):
129133 call_node = node .value
130134 await_node = node
131135 if isinstance (call_node .func , ast .Name ):
@@ -192,7 +196,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
192196
193197 return node
194198
195- def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef , test_class_name : str | None = None ) -> ast .AsyncFunctionDef :
199+ def visit_AsyncFunctionDef (
200+ self , node : ast .AsyncFunctionDef , test_class_name : str | None = None
201+ ) -> ast .AsyncFunctionDef :
196202 """Handle async function definitions by converting to sync and back."""
197203 # Convert to sync FunctionDef, process it, then convert back
198204 sync_node = ast .FunctionDef (
@@ -202,7 +208,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: st
202208 decorator_list = node .decorator_list ,
203209 returns = node .returns ,
204210 lineno = node .lineno ,
205- col_offset = node .col_offset if hasattr (node , ' col_offset' ) else 0
211+ col_offset = node .col_offset if hasattr (node , " col_offset" ) else 0 ,
206212 )
207213 processed_sync = self .visit_FunctionDef (sync_node , test_class_name )
208214 # Convert back to AsyncFunctionDef
@@ -213,9 +219,9 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: st
213219 decorator_list = processed_sync .decorator_list ,
214220 returns = processed_sync .returns ,
215221 lineno = processed_sync .lineno ,
216- col_offset = processed_sync .col_offset if hasattr (processed_sync , ' col_offset' ) else 0
222+ col_offset = processed_sync .col_offset if hasattr (processed_sync , " col_offset" ) else 0 ,
217223 )
218-
224+
219225 def visit_FunctionDef (self , node : ast .FunctionDef , test_class_name : str | None = None ) -> ast .FunctionDef :
220226 if node .name .startswith ("test_" ):
221227 did_update = False
0 commit comments