def fast_remove_docstrings_from_ast(node):
    """Remove docstrings from ``node`` in place, skipping expression subtrees.

    A docstring is a string-constant expression statement in first position
    of the body of a ``Module``, ``ClassDef``, ``FunctionDef`` or
    ``AsyncFunctionDef``. Definitions may be nested inside any compound
    statement (``if``, ``for``, ``while``, ``with``, ``try``/``except``,
    ``match``), so the traversal follows every statement-level child rather
    than only the direct def/class children of a body; otherwise those nested
    definitions would silently keep their docstrings.

    Args:
        node: Root AST node to process. It is mutated in place.

    Returns:
        None.

    Note:
        A body that consists solely of a docstring is left empty (``[]``)
        after removal; no placeholder statement is inserted.
    """
    # Only these node types can carry a docstring as body[0].
    docstring_owners = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)

    # Node kinds that can (transitively) hold nested statements. Expression
    # nodes are deliberately excluded: they cannot contain def/class
    # statements, so skipping them keeps the traversal cheap.
    statement_holders = (ast.stmt, ast.excepthandler)
    match_case = getattr(ast, "match_case", None)  # absent before Python 3.10
    if match_case is not None:
        statement_holders += (match_case,)

    # Explicit stack instead of recursion so deep nesting cannot hit the
    # interpreter's recursion limit.
    stack = [node]
    while stack:
        current_node = stack.pop()

        if isinstance(current_node, docstring_owners):
            body = current_node.body
            if (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            ):
                current_node.body = body[1:]

        # Runs after the body rewrite, so the removed docstring node is not
        # visited. Covers body/orelse/finalbody/handlers/cases uniformly.
        stack.extend(
            child
            for child in ast.iter_child_nodes(current_node)
            if isinstance(child, statement_holders)
        )