@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335
335
return updated_node
336
336
337
337
338
- def extract_global_statements (source_code : str ) -> list [cst .SimpleStatementLine ]:
338
+ def extract_global_statements (source_code : str ) -> tuple [ cst . Module , list [cst .SimpleStatementLine ] ]:
339
339
"""Extract global statements from source code."""
340
340
module = cst .parse_module (source_code )
341
341
collector = GlobalStatementCollector ()
342
342
module .visit (collector )
343
- return collector .global_statements
343
+ return module , collector .global_statements
344
344
345
345
346
346
def find_last_import_line (target_code : str ) -> int :
@@ -373,16 +373,11 @@ def delete___future___aliased_imports(module_code: str) -> str:
373
373
374
374
375
375
def add_global_assignments (src_module_code : str , dst_module_code : str ) -> str :
376
- # Avoid repeat parses and visits
377
- src_module , new_added_global_statements = _extract_global_statements_once (src_module_code )
378
- dst_module , existing_global_statements = _extract_global_statements_once (dst_module_code )
376
+ src_module , new_added_global_statements = extract_global_statements (src_module_code )
377
+ dst_module , existing_global_statements = extract_global_statements (dst_module_code )
379
378
380
- # Build a list of global statements which are not already present using more efficient membership test.
381
- # Slightly optimized by making a set of (hashable deep identity) for comparison.
382
- # However, since CST nodes are not hashable, continue using deep_equals but do NOT recompute for identical object references.
383
379
unique_global_statements = []
384
380
for stmt in new_added_global_statements :
385
- # Fast path: check by id
386
381
if any (
387
382
stmt is existing_stmt or stmt .deep_equals (existing_stmt ) for existing_stmt in existing_global_statements
388
383
):
@@ -651,11 +646,3 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP
651
646
if isinstance (cnode , (ast .FunctionDef , ast .AsyncFunctionDef )):
652
647
preexisting_objects .add ((cnode .name , (FunctionParent (node .name , "ClassDef" ),)))
653
648
return preexisting_objects
654
-
655
-
656
- def _extract_global_statements_once (source_code : str ):
657
- """Extract global statements once and return both module and statements (internal)"""
658
- module = cst .parse_module (source_code )
659
- collector = GlobalStatementCollector ()
660
- module .visit (collector )
661
- return module , collector .global_statements
0 commit comments