diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index ed7e6120..0de8ade7 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -32,7 +32,7 @@ def __init__( def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: self.context_stack.append(node.name) - for inner_node in ast.walk(node): + for inner_node in node.body: if isinstance(inner_node, (ast.FunctionDef, ast.AsyncFunctionDef)): self.visit_FunctionDef(inner_node) self.context_stack.pop() @@ -55,7 +55,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: return self._process_function_def(node) - def _process_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.FunctionDef | ast.AsyncFunctionDef: + def _process_function_def( + self, node: ast.FunctionDef | ast.AsyncFunctionDef + ) -> ast.FunctionDef | ast.AsyncFunctionDef: self.context_stack.append(node.name) i = len(node.body) - 1 test_qualified_name = ".".join(self.context_stack)