@@ -32,9 +32,11 @@ def __init__(
3232
3333 def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
3434 self .context_stack .append (node .name )
35- for inner_node in ast . walk ( node ) :
35+ for inner_node in node . body :
3636 if isinstance (inner_node , ast .FunctionDef ):
3737 self .visit_FunctionDef (inner_node )
38+ elif isinstance (inner_node , ast .AsyncFunctionDef ):
39+ self .visit_AsyncFunctionDef (inner_node )
3840 self .context_stack .pop ()
3941 return node
4042
@@ -50,6 +52,14 @@ def get_comment(self, match_key: str) -> str:
5052 return f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
5153
5254 def visit_FunctionDef (self , node : ast .FunctionDef ) -> ast .FunctionDef :
55+ self ._process_function_def_common (node )
56+ return node
57+
58+ def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> ast .AsyncFunctionDef :
59+ self ._process_function_def_common (node )
60+ return node
61+
62+ def _process_function_def_common (self , node : ast .FunctionDef | ast .AsyncFunctionDef ) -> None :
5363 self .context_stack .append (node .name )
5464 i = len (node .body ) - 1
5565 test_qualified_name = "." .join (self .context_stack )
@@ -60,8 +70,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
6070 j = len (line_node .body ) - 1
6171 while j >= 0 :
6272 compound_line_node : ast .stmt = line_node .body [j ]
63- internal_node : ast .AST
64- for internal_node in ast .walk (compound_line_node ):
73+ nodes_to_check = [compound_line_node ]
74+ nodes_to_check .extend (getattr (compound_line_node , "body" , []))
75+ for internal_node in nodes_to_check :
6576 if isinstance (internal_node , (ast .stmt , ast .Assign )):
6677 inv_id = str (i ) + "_" + str (j )
6778 match_key = key + "#" + inv_id
@@ -75,7 +86,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
7586 self .results [line_node .lineno ] = self .get_comment (match_key )
7687 i -= 1
7788 self .context_stack .pop ()
78- return node
7989
8090
8191def get_fn_call_linenos (
@@ -201,7 +211,7 @@ def remove_functions_from_generated_tests(
201211 for generated_test in generated_tests .generated_tests :
202212 for test_function in test_functions_to_remove :
203213 function_pattern = re .compile (
204- rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{ re .escape (test_function )} \(.*?\):.*?(?=\ndef \s|$)" ,
214+ rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)? def\s+{ re .escape (test_function )} \(.*?\):.*?(?=\n(async\s+)?def \s|$)" ,
205215 re .DOTALL ,
206216 )
207217
0 commit comments