Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ast
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set

import libcst as cst
import libcst.matchers as m
Expand All @@ -18,7 +18,6 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize

from typing import List, Union

class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements."""
Expand Down Expand Up @@ -112,15 +111,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

# Add the new assignments
for assignment in assignments_to_append:
new_statements.append(
cst.SimpleStatementLine(
[assignment],
leading_lines=[cst.EmptyLine()]
)
)
new_statements.append(cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]))

return updated_node.with_changes(body=new_statements)


class GlobalStatementCollector(cst.CSTVisitor):
"""Visitor that collects all global statements (excluding imports and functions/classes)."""

Expand Down Expand Up @@ -204,17 +199,14 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
    """Extract global statements from source code.

    Convenience wrapper for callers that only hold source text: the text is
    parsed once and the collection itself is delegated to
    ``extract_global_statements_from_module`` so the logic lives in one place.

    Args:
        source_code: Python source to parse with libcst.

    Returns:
        The statements gathered by ``GlobalStatementCollector`` for the
        parsed module.
    """
    module = cst.parse_module(source_code)
    return extract_global_statements_from_module(module)


def find_last_import_line(target_code: str) -> int:
    """Find the line number of the last import statement.

    Convenience wrapper for callers that only hold source text: the text is
    parsed once and the lookup is delegated to
    ``find_last_import_line_from_module`` so the logic lives in one place.

    Args:
        target_code: Python source to parse with libcst.

    Returns:
        The ``last_import_line`` value reported by ``LastImportFinder`` for
        the parsed module.
    """
    module = cst.parse_module(target_code)
    return find_last_import_line_from_module(module)


class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
Expand All @@ -236,35 +228,34 @@ def delete___future___aliased_imports(module_code: str) -> str:


def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
    """Carry global statements and global assignments from one module into another.

    The work happens in two passes over the destination:

    1. The global statements extracted from the source module are inserted
       into the destination by ``ImportInserter``, positioned using the line
       of the destination's last import.
    2. The assignments collected from the source module by
       ``GlobalAssignmentCollector`` are applied to the result by
       ``GlobalAssignmentTransformer`` (the exact merge policy is defined by
       that transformer).

    Args:
        src_module_code: Source text of the module providing the statements
            and assignments.
        dst_module_code: Source text of the module to be updated.

    Returns:
        The updated destination module rendered back to source text.
    """
    # Parse each input exactly once; both trees are reused below.
    src_module = cst.parse_module(src_module_code)
    dst_module = cst.parse_module(dst_module_code)

    # Pass 1: insert the source's global statements relative to the last import.
    non_assignment_global_statements = extract_global_statements_from_module(src_module)
    last_import_line = find_last_import_line_from_module(dst_module)
    statement_inserter = ImportInserter(non_assignment_global_statements, last_import_line)
    dst_with_statements = dst_module.visit(statement_inserter)

    # Round-trip through source text so pass 2 operates on a freshly parsed
    # tree that already contains the inserted statements.
    reparsed_dst_module = cst.parse_module(dst_with_statements.code)

    # Pass 2: collect assignments from the (already parsed) source module and
    # apply them to the destination. The collector visits the source only once.
    assignment_collector = GlobalAssignmentCollector()
    src_module.visit(assignment_collector)
    assignment_transformer = GlobalAssignmentTransformer(
        assignment_collector.assignments, assignment_collector.assignment_order
    )
    transformed_dst_module = reparsed_dst_module.visit(assignment_transformer)

    return transformed_dst_module.code


def add_needed_imports_from_module(
Expand Down Expand Up @@ -481,3 +472,17 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
return preexisting_objects


def extract_global_statements_from_module(module: cst.Module) -> List[cst.SimpleStatementLine]:
    """Return the global statements found in an already-parsed module.

    Runs a fresh ``GlobalStatementCollector`` over ``module`` and hands back
    whatever it accumulated in ``global_statements``.
    """
    statement_collector = GlobalStatementCollector()
    module.visit(statement_collector)
    return statement_collector.global_statements


def find_last_import_line_from_module(module: cst.Module) -> int:
    """Return the last import line recorded for an already-parsed module.

    Runs a fresh ``LastImportFinder`` over ``module`` and hands back its
    ``last_import_line`` attribute.
    """
    import_finder = LastImportFinder()
    module.visit(import_finder)
    return import_finder.last_import_line
Loading