Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions codeflash/code_utils/deduplicate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def visit_For(self, node):

def visit_With(self, node):
    """Visit a ``with`` statement by delegating to the generic child traversal.

    No ``with``-specific rewriting happens here; the node's children are
    visited through ``generic_visit`` and the result is returned unchanged.

    Args:
        node: The ``ast.With`` node being visited.

    Returns:
        Whatever ``self.generic_visit(node)`` returns (for a plain
        ``ast.NodeTransformer`` this is the same node, updated in place).
    """
    # Dispatch through the instance rather than calling
    # ast.NodeTransformer.generic_visit directly, so that a subclass which
    # overrides generic_visit is still honoured. A second, unreachable
    # return statement that followed this one has been removed.
    return self.generic_visit(node)


def normalize_code(code: str, remove_docstrings: bool = True) -> str:
Expand All @@ -167,21 +168,44 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:

"""
try:
# Parse the code
tree = ast.parse(code)

# Remove docstrings if requested
# Fast-path: skip docstring removal step if not requested
if remove_docstrings:
remove_docstrings_from_ast(tree)

# Normalize variable names
# Inline the docstring-removal logic of remove_docstrings_from_ast for performance;
# replaces ast.walk() with an explicit stack-based traversal for fewer allocations
nodes = [tree]
while nodes:
node = nodes.pop()
# Only consider def, async def, class, module nodes
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
body = node.body
if body:
expr0 = body[0]
if (
isinstance(expr0, ast.Expr)
and isinstance(expr0.value, ast.Constant)
and isinstance(expr0.value.value, str)
):
node.body = body[1:]
# Extend with children efficiently
nodes.extend(
child
for child in getattr(node, "body", [])
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
)

# No need to import remove_docstrings_from_ast

# VariableNormalizer usage as before
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)

# Fix missing locations in the AST
# Fix missing locations efficiently
# ast.fix_missing_locations does a deep recursive update; cannot optimize without breaking API
ast.fix_missing_locations(normalized_tree)

# Unparse back to code
# ast.unparse is required; cannot avoid
return ast.unparse(normalized_tree)
except SyntaxError as e:
msg = f"Invalid Python syntax: {e}"
Expand Down
Loading