11from __future__ import annotations
22
33from collections .abc import Sequence
4- from typing import cast
4+ from typing import Union , cast
55
66import libcst as cst
77import libcst .matchers as m
1414
1515
1616class _RemoveUnusedImportTransformer (cst .CSTTransformer ):
17- __slots__ = ("unused_imports" ,)
17+ __slots__ = ("unused_imports" , "_nodes_to_remove" , "_pending_lines" , "_pending_lines_stack" )
1818
1919 METADATA_DEPENDENCIES = (PositionProvider ,)
2020
2121 def __init__ (self , unused_imports : list [Import | ImportFrom ]) -> None :
2222 super ().__init__ ()
2323
2424 self .unused_imports = unused_imports
25+ self ._nodes_to_remove : set [int ] = set ()
26+ self ._pending_lines : list [cst .EmptyLine ] = []
27+ self ._pending_lines_stack : list [list [cst .EmptyLine ]] = []
2528
2629 @staticmethod
2730 def get_import_name_from_attr (attr_node : cst .Attribute ) -> str :
@@ -61,7 +64,7 @@ def get_rpar(rpar: cst.RightParen | None, location: CodeRange) -> cst.RightParen
6164
6265 def leave_import_alike (
6366 self , original_node : T .CSTImportT , updated_node : T .CSTImportT
64- ) -> cst . RemovalSentinel | T .CSTImportT :
67+ ) -> T .CSTImportT :
6568 names_to_keep = []
6669 names = cast (Sequence [cst .ImportAlias ], updated_node .names )
6770 # already handled by leave_ImportFrom
@@ -78,7 +81,8 @@ def leave_import_alike(
7881 if self .is_import_used (import_name , column + 1 , self .get_location (original_node )):
7982 names_to_keep .append (import_alias )
8083 if not names_to_keep :
81- return cst .RemoveFromParent ()
84+ self ._nodes_to_remove .add (id (original_node ))
85+ return updated_node
8286 elif len (names ) == len (names_to_keep ):
8387 return updated_node
8488 else :
@@ -91,19 +95,17 @@ def leave_import_alike(
9195 return cast (T .CSTImportT , updated_node )
9296
9397 @staticmethod
94- def leave_StarImport (updated_node : cst .ImportFrom , imp : ImportFrom ) -> cst .ImportFrom | cst . RemovalSentinel :
98+ def leave_StarImport (updated_node : cst .ImportFrom , imp : ImportFrom ) -> tuple [ cst .ImportFrom , bool ] :
9599 if imp .suggestions :
96100 names_to_suggestions = [cst .ImportAlias (cst .Name (module )) for module in imp .suggestions ]
97- return updated_node .with_changes (names = names_to_suggestions )
101+ return updated_node .with_changes (names = names_to_suggestions ), False
98102 else :
99- return cst . RemoveFromParent ()
103+ return updated_node , True
100104
101- def leave_Import (self , original_node : cst .Import , updated_node : cst .Import ) -> cst .RemovalSentinel | cst . Import :
105+ def leave_Import (self , original_node : cst .Import , updated_node : cst .Import ) -> cst .Import :
102106 return self .leave_import_alike (original_node , updated_node )
103107
104- def leave_ImportFrom (
105- self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom
106- ) -> cst .RemovalSentinel | cst .ImportFrom :
108+ def leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) -> cst .ImportFrom :
107109 if isinstance (updated_node .names , cst .ImportStar ):
108110
109111 def get_star_imp () -> ImportFrom | None :
@@ -120,12 +122,149 @@ def get_star_imp() -> ImportFrom | None:
120122
121123 imp = get_star_imp ()
122124 if imp :
123- return self .leave_StarImport (updated_node , imp )
125+ result , should_remove = self .leave_StarImport (updated_node , imp )
126+ if should_remove :
127+ self ._nodes_to_remove .add (id (original_node ))
128+ return result
124129 else :
125130 return original_node
126131
127132 return self .leave_import_alike (original_node , updated_node )
128133
134+ def leave_SimpleStatementLine (
135+ self ,
136+ original_node : cst .SimpleStatementLine ,
137+ updated_node : cst .SimpleStatementLine ,
138+ ) -> Union [cst .SimpleStatementLine , cst .RemovalSentinel ]:
139+ # Check if any child import node was marked for removal
140+ should_remove = False
141+ for stmt in original_node .body :
142+ if id (stmt ) in self ._nodes_to_remove :
143+ should_remove = True
144+ break
145+
146+ if should_remove :
147+ # Extract comment-bearing lines (and blank lines that precede them)
148+ # from leading_lines and stash them
149+ lines = list (updated_node .leading_lines )
150+ preserved : list [cst .EmptyLine ] = []
151+ for i , line in enumerate (lines ):
152+ if isinstance (line , cst .EmptyLine ) and line .comment is not None :
153+ # Also include blank lines immediately before this comment
154+ j = i - 1
155+ blank_prefix : list [cst .EmptyLine ] = []
156+ while j >= 0 and isinstance (lines [j ], cst .EmptyLine ) and lines [j ].comment is None :
157+ blank_prefix .append (lines [j ])
158+ j -= 1
159+ blank_prefix .reverse ()
160+ preserved .extend (blank_prefix )
161+ preserved .append (line )
162+ self ._pending_lines .extend (preserved )
163+ return cst .RemoveFromParent ()
164+
165+ # If there are pending comment lines, prepend them to this statement
166+ if self ._pending_lines :
167+ new_leading = list (self ._pending_lines ) + list (updated_node .leading_lines )
168+ self ._pending_lines .clear ()
169+ return updated_node .with_changes (leading_lines = new_leading )
170+
171+ return updated_node
172+
173+ # -- Compound statement scope isolation --
174+ # Push/pop pending lines so that nested statements don't consume
175+ # pending lines from the outer scope.
176+
177+ def _push_pending (self ) -> None :
178+ self ._pending_lines_stack .append (self ._pending_lines )
179+ self ._pending_lines = []
180+
181+ def _pop_and_apply (self , updated_node : cst .BaseCompoundStatement ) -> cst .BaseCompoundStatement :
182+ self ._pending_lines = self ._pending_lines_stack .pop ()
183+ if self ._pending_lines :
184+ new_leading = list (self ._pending_lines ) + list (updated_node .leading_lines )
185+ self ._pending_lines .clear ()
186+ return updated_node .with_changes (leading_lines = new_leading )
187+ return updated_node
188+
189+ def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
190+ self ._push_pending ()
191+ return True
192+
193+ def leave_ClassDef (
194+ self , original_node : cst .ClassDef , updated_node : cst .ClassDef
195+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
196+ return self ._pop_and_apply (updated_node )
197+
198+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
199+ self ._push_pending ()
200+ return True
201+
202+ def leave_FunctionDef (
203+ self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
204+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
205+ return self ._pop_and_apply (updated_node )
206+
207+ def visit_If (self , node : cst .If ) -> bool :
208+ self ._push_pending ()
209+ return True
210+
211+ def leave_If (
212+ self , original_node : cst .If , updated_node : cst .If
213+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
214+ return self ._pop_and_apply (updated_node )
215+
216+ def visit_For (self , node : cst .For ) -> bool :
217+ self ._push_pending ()
218+ return True
219+
220+ def leave_For (
221+ self , original_node : cst .For , updated_node : cst .For
222+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
223+ return self ._pop_and_apply (updated_node )
224+
225+ def visit_While (self , node : cst .While ) -> bool :
226+ self ._push_pending ()
227+ return True
228+
229+ def leave_While (
230+ self , original_node : cst .While , updated_node : cst .While
231+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
232+ return self ._pop_and_apply (updated_node )
233+
234+ def visit_Try (self , node : cst .Try ) -> bool :
235+ self ._push_pending ()
236+ return True
237+
238+ def leave_Try (
239+ self , original_node : cst .Try , updated_node : cst .Try
240+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
241+ return self ._pop_and_apply (updated_node )
242+
243+ def visit_TryStar (self , node : cst .TryStar ) -> bool :
244+ self ._push_pending ()
245+ return True
246+
247+ def leave_TryStar (
248+ self , original_node : cst .TryStar , updated_node : cst .TryStar
249+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
250+ return self ._pop_and_apply (updated_node )
251+
252+ def visit_With (self , node : cst .With ) -> bool :
253+ self ._push_pending ()
254+ return True
255+
256+ def leave_With (
257+ self , original_node : cst .With , updated_node : cst .With
258+ ) -> Union [cst .BaseStatement , cst .RemovalSentinel ]:
259+ return self ._pop_and_apply (updated_node )
260+
261+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
262+ if self ._pending_lines :
263+ new_footer = list (updated_node .footer ) + list (self ._pending_lines )
264+ self ._pending_lines .clear ()
265+ return updated_node .with_changes (footer = new_footer )
266+ return updated_node
267+
129268
130269def refactor_string (source : str , unused_imports : list [Import | ImportFrom ]) -> str :
131270 if unused_imports :
0 commit comments