1414
1515
1616class _RemoveUnusedImportTransformer (cst .CSTTransformer ):
17- __slots__ = ("unused_imports" ,)
17+ __slots__ = ("unused_imports" , "_nodes_to_remove" , "_pending_lines" , "_pending_lines_stack" )
1818
1919 METADATA_DEPENDENCIES = (PositionProvider ,)
2020
2121 def __init__ (self , unused_imports : list [Import | ImportFrom ]) -> None :
2222 super ().__init__ ()
2323
2424 self .unused_imports = unused_imports
25+ self ._nodes_to_remove : set [int ] = set ()
26+ self ._pending_lines : list [cst .EmptyLine ] = []
27+ self ._pending_lines_stack : list [list [cst .EmptyLine ]] = []
2528
2629 @staticmethod
2730 def get_import_name_from_attr (attr_node : cst .Attribute ) -> str :
@@ -59,9 +62,7 @@ def get_rpar(rpar: cst.RightParen | None, location: CodeRange) -> cst.RightParen
5962 else :
6063 return cst .RightParen (whitespace_before = cst .ParenthesizedWhitespace ())
6164
62- def leave_import_alike (
63- self , original_node : T .CSTImportT , updated_node : T .CSTImportT
64- ) -> cst .RemovalSentinel | T .CSTImportT :
65+ def leave_import_alike (self , original_node : T .CSTImportT , updated_node : T .CSTImportT ) -> T .CSTImportT :
6566 names_to_keep = []
6667 names = cast (Sequence [cst .ImportAlias ], updated_node .names )
6768 # already handled by leave_ImportFrom
@@ -78,7 +79,8 @@ def leave_import_alike(
7879 if self .is_import_used (import_name , column + 1 , self .get_location (original_node )):
7980 names_to_keep .append (import_alias )
8081 if not names_to_keep :
81- return cst .RemoveFromParent ()
82+ self ._nodes_to_remove .add (id (original_node ))
83+ return updated_node
8284 elif len (names ) == len (names_to_keep ):
8385 return updated_node
8486 else :
@@ -91,19 +93,17 @@ def leave_import_alike(
9193 return cast (T .CSTImportT , updated_node )
9294
9395 @staticmethod
94- def leave_StarImport (updated_node : cst .ImportFrom , imp : ImportFrom ) -> cst .ImportFrom | cst . RemovalSentinel :
96+ def leave_StarImport (updated_node : cst .ImportFrom , imp : ImportFrom ) -> tuple [ cst .ImportFrom , bool ] :
9597 if imp .suggestions :
9698 names_to_suggestions = [cst .ImportAlias (cst .Name (module )) for module in imp .suggestions ]
97- return updated_node .with_changes (names = names_to_suggestions )
99+ return updated_node .with_changes (names = names_to_suggestions ), False
98100 else :
99- return cst . RemoveFromParent ()
101+ return updated_node , True
100102
101- def leave_Import (self , original_node : cst .Import , updated_node : cst .Import ) -> cst .RemovalSentinel | cst . Import :
103+ def leave_Import (self , original_node : cst .Import , updated_node : cst .Import ) -> cst .Import :
102104 return self .leave_import_alike (original_node , updated_node )
103105
104- def leave_ImportFrom (
105- self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom
106- ) -> cst .RemovalSentinel | cst .ImportFrom :
106+ def leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) -> cst .ImportFrom :
107107 if isinstance (updated_node .names , cst .ImportStar ):
108108
109109 def get_star_imp () -> ImportFrom | None :
@@ -120,12 +120,139 @@ def get_star_imp() -> ImportFrom | None:
120120
121121 imp = get_star_imp ()
122122 if imp :
123- return self .leave_StarImport (updated_node , imp )
123+ result , should_remove = self .leave_StarImport (updated_node , imp )
124+ if should_remove :
125+ self ._nodes_to_remove .add (id (original_node ))
126+ return result
124127 else :
125128 return original_node
126129
127130 return self .leave_import_alike (original_node , updated_node )
128131
132+ def leave_SimpleStatementLine (
133+ self ,
134+ original_node : cst .SimpleStatementLine ,
135+ updated_node : cst .SimpleStatementLine ,
136+ ) -> cst .SimpleStatementLine | cst .RemovalSentinel :
137+ # Check if any child import node was marked for removal
138+ should_remove = False
139+ for stmt in original_node .body :
140+ if id (stmt ) in self ._nodes_to_remove :
141+ should_remove = True
142+ break
143+
144+ if should_remove :
145+ # Extract comment-bearing lines (and blank lines that precede them)
146+ # from leading_lines and stash them
147+ lines = list (updated_node .leading_lines )
148+ preserved : list [cst .EmptyLine ] = []
149+ for i , line in enumerate (lines ):
150+ if isinstance (line , cst .EmptyLine ) and line .comment is not None :
151+ # Also include blank lines immediately before this comment
152+ j = i - 1
153+ blank_prefix : list [cst .EmptyLine ] = []
154+ while j >= 0 and isinstance (lines [j ], cst .EmptyLine ) and lines [j ].comment is None :
155+ blank_prefix .append (lines [j ])
156+ j -= 1
157+ blank_prefix .reverse ()
158+ preserved .extend (blank_prefix )
159+ preserved .append (line )
160+ self ._pending_lines .extend (preserved )
161+ return cst .RemoveFromParent ()
162+
163+ # If there are pending comment lines, prepend them to this statement
164+ if self ._pending_lines :
165+ new_leading = list (self ._pending_lines ) + list (updated_node .leading_lines )
166+ self ._pending_lines .clear ()
167+ return updated_node .with_changes (leading_lines = new_leading )
168+
169+ return updated_node
170+
171+ # -- Compound statement scope isolation --
172+ # Push/pop pending lines so that nested statements don't consume
173+ # pending lines from the outer scope.
174+
175+ def _push_pending (self ) -> None :
176+ self ._pending_lines_stack .append (self ._pending_lines )
177+ self ._pending_lines = []
178+
179+ def _pop_and_apply (self , updated_node : cst .BaseCompoundStatement ) -> cst .BaseCompoundStatement :
180+ self ._pending_lines = self ._pending_lines_stack .pop ()
181+ if self ._pending_lines :
182+ new_leading = list (self ._pending_lines ) + list (updated_node .leading_lines )
183+ self ._pending_lines .clear ()
184+ return updated_node .with_changes (leading_lines = new_leading )
185+ return updated_node
186+
187+ def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
188+ self ._push_pending ()
189+ return True
190+
191+ def leave_ClassDef (
192+ self , original_node : cst .ClassDef , updated_node : cst .ClassDef
193+ ) -> cst .BaseStatement | cst .RemovalSentinel :
194+ return self ._pop_and_apply (updated_node )
195+
196+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
197+ self ._push_pending ()
198+ return True
199+
200+ def leave_FunctionDef (
201+ self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
202+ ) -> cst .BaseStatement | cst .RemovalSentinel :
203+ return self ._pop_and_apply (updated_node )
204+
205+ def visit_If (self , node : cst .If ) -> bool :
206+ self ._push_pending ()
207+ return True
208+
209+ def leave_If (self , original_node : cst .If , updated_node : cst .If ) -> cst .BaseStatement | cst .RemovalSentinel :
210+ return self ._pop_and_apply (updated_node )
211+
212+ def visit_For (self , node : cst .For ) -> bool :
213+ self ._push_pending ()
214+ return True
215+
216+ def leave_For (self , original_node : cst .For , updated_node : cst .For ) -> cst .BaseStatement | cst .RemovalSentinel :
217+ return self ._pop_and_apply (updated_node )
218+
219+ def visit_While (self , node : cst .While ) -> bool :
220+ self ._push_pending ()
221+ return True
222+
223+ def leave_While (self , original_node : cst .While , updated_node : cst .While ) -> cst .BaseStatement | cst .RemovalSentinel :
224+ return self ._pop_and_apply (updated_node )
225+
226+ def visit_Try (self , node : cst .Try ) -> bool :
227+ self ._push_pending ()
228+ return True
229+
230+ def leave_Try (self , original_node : cst .Try , updated_node : cst .Try ) -> cst .BaseStatement | cst .RemovalSentinel :
231+ return self ._pop_and_apply (updated_node )
232+
233+ def visit_TryStar (self , node : cst .TryStar ) -> bool :
234+ self ._push_pending ()
235+ return True
236+
237+ def leave_TryStar (
238+ self , original_node : cst .TryStar , updated_node : cst .TryStar
239+ ) -> cst .BaseStatement | cst .RemovalSentinel :
240+ return self ._pop_and_apply (updated_node )
241+
242+ def visit_With (self , node : cst .With ) -> bool :
243+ self ._push_pending ()
244+ return True
245+
246+ def leave_With (self , original_node : cst .With , updated_node : cst .With ) -> cst .BaseStatement | cst .RemovalSentinel :
247+ return self ._pop_and_apply (updated_node )
248+
249+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
250+ if self ._pending_lines :
251+ new_footer = list (updated_node .footer ) + list (self ._pending_lines )
252+ self ._pending_lines .clear ()
253+ return updated_node .with_changes (footer = new_footer )
254+ return updated_node
255+
129256
130257def refactor_string (source : str , unused_imports : list [Import | ImportFrom ]) -> str :
131258 if unused_imports :
0 commit comments