@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335
335
return updated_node
336
336
337
337
338
- def extract_global_statements (source_code : str ) -> list [cst .SimpleStatementLine ]:
338
+ def extract_global_statements (source_code : str ) -> tuple [ cst . Module , list [cst .SimpleStatementLine ] ]:
339
339
"""Extract global statements from source code."""
340
340
module = cst .parse_module (source_code )
341
341
collector = GlobalStatementCollector ()
342
342
module .visit (collector )
343
- return collector .global_statements
343
+ return module , collector .global_statements
344
344
345
345
346
346
def find_last_import_line (target_code : str ) -> int :
@@ -373,39 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373
373
374
374
375
375
def add_global_assignments (src_module_code : str , dst_module_code : str ) -> str :
376
- new_added_global_statements = extract_global_statements (src_module_code )
377
- existing_global_statements = extract_global_statements (dst_module_code )
376
+ src_module , new_added_global_statements = extract_global_statements (src_module_code )
377
+ dst_module , existing_global_statements = extract_global_statements (dst_module_code )
378
378
379
- # make sure we don't have any staments applited multiple times in the global level.
380
- unique_global_statements = [
381
- stmt
382
- for stmt in new_added_global_statements
383
- if not any (stmt .deep_equals (existing_stmt ) for existing_stmt in existing_global_statements )
384
- ]
379
+ unique_global_statements = []
380
+ for stmt in new_added_global_statements :
381
+ if any (
382
+ stmt is existing_stmt or stmt .deep_equals (existing_stmt ) for existing_stmt in existing_global_statements
383
+ ):
384
+ continue
385
+ unique_global_statements .append (stmt )
385
386
387
+ mod_dst_code = dst_module_code
388
+ # Insert unique global statements if any
386
389
if unique_global_statements :
387
- # Find the last import line in target
388
390
last_import_line = find_last_import_line (dst_module_code )
389
-
390
- # Parse the target code
391
- target_module = cst .parse_module (dst_module_code )
392
-
393
- # Create transformer to insert new statements
391
+ # Reuse already-parsed dst_module
394
392
transformer = ImportInserter (unique_global_statements , last_import_line )
395
- #
396
- # # Apply transformation
397
- modified_module = target_module . visit ( transformer )
398
- dst_module_code = modified_module . code
399
-
400
- # Parse the code
401
- original_module = cst . parse_module ( dst_module_code )
402
- new_module = cst . parse_module ( src_module_code )
393
+ # Use visit inplace, don't parse again
394
+ modified_module = dst_module . visit ( transformer )
395
+ mod_dst_code = modified_module . code
396
+ # Parse the code after insertion
397
+ original_module = cst . parse_module ( mod_dst_code )
398
+ else :
399
+ # No new statements to insert, reuse already-parsed dst_module
400
+ original_module = dst_module
403
401
402
+ # Parse the src_module_code once only (already done above: src_module)
404
403
# Collect assignments from the new file
405
404
new_collector = GlobalAssignmentCollector ()
406
- new_module .visit (new_collector )
405
+ src_module .visit (new_collector )
406
+ # Only create transformer if there are assignments to insert/transform
407
+ if not new_collector .assignments : # nothing to transform
408
+ return mod_dst_code
407
409
408
- # Transform the original file
410
+ # Transform the original destination module
409
411
transformer = GlobalAssignmentTransformer (new_collector .assignments , new_collector .assignment_order )
410
412
transformed_module = original_module .visit (transformer )
411
413
0 commit comments