diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 0bcacc025..73532d576 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -62,11 +62,12 @@ def visit_Call(self, node): # type: ignore[no-untyped-def] # noqa: ANN201, ANN0 self.results.append(node.lineno - 1) self.generic_visit(node) - def _get_called_func_name(self, node): # type: ignore[no-untyped-def] # noqa: ANN001, ANN202 + def _get_called_func_name(self, node): """Return name of called fn.""" - if isinstance(node, ast.Name): + t = type(node) + if t is _ast_Name: return node.id - if isinstance(node, ast.Attribute): + if t is _ast_Attribute: return node.attr return None @@ -272,3 +273,8 @@ def add_runtime_comments_to_generated_tests( modified_tests.append(test) return GeneratedTestsList(generated_tests=modified_tests) + + +_ast_Name = ast.Name + +_ast_Attribute = ast.Attribute