diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4c50d978..52cb80a4 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node -def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]: +def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) collector = GlobalStatementCollector() module.visit(collector) - return collector.global_statements + return module, collector.global_statements def find_last_import_line(target_code: str) -> int: @@ -373,39 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - new_added_global_statements = extract_global_statements(src_module_code) - existing_global_statements = extract_global_statements(dst_module_code) + src_module, new_added_global_statements = extract_global_statements(src_module_code) + dst_module, existing_global_statements = extract_global_statements(dst_module_code) - # make sure we don't have any staments applited multiple times in the global level. 
- unique_global_statements = [ - stmt - for stmt in new_added_global_statements - if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements) - ] + unique_global_statements = [] + for stmt in new_added_global_statements: + if any( + stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements + ): + continue + unique_global_statements.append(stmt) + mod_dst_code = dst_module_code + # Insert unique global statements if any if unique_global_statements: - # Find the last import line in target last_import_line = find_last_import_line(dst_module_code) - - # Parse the target code - target_module = cst.parse_module(dst_module_code) - - # Create transformer to insert new statements + # Reuse already-parsed dst_module transformer = ImportInserter(unique_global_statements, last_import_line) - # - # # Apply transformation - modified_module = target_module.visit(transformer) - dst_module_code = modified_module.code - - # Parse the code - original_module = cst.parse_module(dst_module_code) - new_module = cst.parse_module(src_module_code) + # Visit the already-parsed dst_module (visit returns a new tree) instead of re-parsing dst_module_code + modified_module = dst_module.visit(transformer) + mod_dst_code = modified_module.code + # Parse the code after insertion + original_module = cst.parse_module(mod_dst_code) + else: + # No new statements to insert, reuse already-parsed dst_module + original_module = dst_module + # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file new_collector = GlobalAssignmentCollector() - new_module.visit(new_collector) + src_module.visit(new_collector) + # Only create transformer if there are assignments to insert/transform + if not new_collector.assignments: # nothing to transform + return mod_dst_code - # Transform the original file + # Transform the original destination module transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) 
transformed_module = original_module.visit(transformer)