diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 96b6dd845..2a129a6a1 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,7 +2,7 @@ import ast from pathlib import Path -from typing import TYPE_CHECKING, Dict, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import libcst as cst import libcst.matchers as m @@ -18,7 +18,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from typing import List, Union class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" @@ -112,15 +111,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Add the new assignments for assignment in assignments_to_append: - new_statements.append( - cst.SimpleStatementLine( - [assignment], - leading_lines=[cst.EmptyLine()] - ) - ) + new_statements.append(cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])) return updated_node.with_changes(body=new_statements) + class GlobalStatementCollector(cst.CSTVisitor): """Visitor that collects all global statements (excluding imports and functions/classes).""" @@ -204,17 +199,14 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]: """Extract global statements from source code.""" module = cst.parse_module(source_code) - collector = GlobalStatementCollector() - module.visit(collector) - return collector.global_statements + return extract_global_statements_from_module(module) def find_last_import_line(target_code: str) -> int: """Find the line number of the last import statement.""" module = cst.parse_module(target_code) - finder = LastImportFinder() - module.visit(finder) - return finder.last_import_line + return find_last_import_line_from_module(module) + class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( @@ -236,35 +228,34 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - non_assignment_global_statements = extract_global_statements(src_module_code) + # Parse both modules only once + src_module = cst.parse_module(src_module_code) + dst_module = cst.parse_module(dst_module_code) - # Find the last import line in target - last_import_line = find_last_import_line(dst_module_code) + # Extract statements just once, given a module + non_assignment_global_statements = extract_global_statements_from_module(src_module) - # Parse the target code - target_module = cst.parse_module(dst_module_code) + # Find the last import line, given the target module + last_import_line = find_last_import_line_from_module(dst_module) - # Create transformer to insert non_assignment_global_statements + # Insert global statements with a single transformation transformer = ImportInserter(non_assignment_global_statements, last_import_line) - # - # # Apply transformation - modified_module = target_module.visit(transformer) - dst_module_code = modified_module.code + modified_dst_module = dst_module.visit(transformer) + mid_dst_code = modified_dst_module.code - # Parse the code - original_module = cst.parse_module(dst_module_code) - new_module = cst.parse_module(src_module_code) + # Only parse the code after import insertion once + modified_dst_module2 = cst.parse_module(mid_dst_code) - # Collect assignments from the new file + # Collect assignments from the parsed src module new_collector = GlobalAssignmentCollector() - new_module.visit(new_collector) + src_module.visit(new_collector) - # Transform the original file - transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) - transformed_module = original_module.visit(transformer) + # Transform the modified_dst_module2 (which has the extra global statements in place) + transformer2 = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) + transformed_dst_module = modified_dst_module2.visit(transformer2) - dst_module_code = transformed_module.code - return dst_module_code + # Return the final code + return transformed_dst_module.code def add_needed_imports_from_module( @@ -481,3 +472,17 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),))) return preexisting_objects + + +def extract_global_statements_from_module(module: cst.Module) -> List[cst.SimpleStatementLine]: + """Extract global statements from parsed module.""" + collector = GlobalStatementCollector() + module.visit(collector) + return collector.global_statements + + +def find_last_import_line_from_module(module: cst.Module) -> int: + """Find the line number of the last import statement in a parsed module.""" + finder = LastImportFinder() + module.visit(finder) + return finder.last_import_line