Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 28 additions & 26 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
return updated_node


def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
"""Extract global statements from source code."""
module = cst.parse_module(source_code)
collector = GlobalStatementCollector()
module.visit(collector)
return collector.global_statements
return module, collector.global_statements


def find_last_import_line(target_code: str) -> int:
Expand Down Expand Up @@ -373,39 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:


def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
new_added_global_statements = extract_global_statements(src_module_code)
existing_global_statements = extract_global_statements(dst_module_code)
src_module, new_added_global_statements = extract_global_statements(src_module_code)
dst_module, existing_global_statements = extract_global_statements(dst_module_code)

# Make sure no statement is applied more than once at the global level.
unique_global_statements = [
stmt
for stmt in new_added_global_statements
if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements)
]
unique_global_statements = []
for stmt in new_added_global_statements:
if any(
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
):
continue
unique_global_statements.append(stmt)

mod_dst_code = dst_module_code
# Insert unique global statements if any
if unique_global_statements:
# Find the last import line in target
last_import_line = find_last_import_line(dst_module_code)

# Parse the target code
target_module = cst.parse_module(dst_module_code)

# Create transformer to insert new statements
# Reuse already-parsed dst_module
transformer = ImportInserter(unique_global_statements, last_import_line)
#
# # Apply transformation
modified_module = target_module.visit(transformer)
dst_module_code = modified_module.code

# Parse the code
original_module = cst.parse_module(dst_module_code)
new_module = cst.parse_module(src_module_code)
# Visit the already-parsed dst_module rather than parsing dst_module_code again
modified_module = dst_module.visit(transformer)
mod_dst_code = modified_module.code
# Parse the code after insertion
original_module = cst.parse_module(mod_dst_code)
else:
# No new statements to insert, reuse already-parsed dst_module
original_module = dst_module

# Parse the src_module_code once only (already done above: src_module)
# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
new_module.visit(new_collector)
src_module.visit(new_collector)
# Only create transformer if there are assignments to insert/transform
if not new_collector.assignments: # nothing to transform
return mod_dst_code

# Transform the original file
# Transform the original destination module
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)

Expand Down
Loading