Skip to content

Commit 936dec2

Browse files
hakancelikdevclaude
andcommitted
fix: preserve comments and blank lines when removing unused imports
When removing an unused import, libcst's RemoveFromParent() also removed comment lines and blank lines above the import. This change uses a two-phase approach: mark imports for removal, then in leave_SimpleStatementLine extract comment-bearing leading lines and transfer them to the next sibling statement. A stack-based push/pop mechanism isolates pending lines across compound statement boundaries (class, function, if, for, etc.) so nested statements don't consume comments from the outer scope. Closes #100 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6945e95 commit 936dec2

22 files changed

+442
-13
lines changed

src/unimport/refactor.py

Lines changed: 140 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515

1616
class _RemoveUnusedImportTransformer(cst.CSTTransformer):
17-
__slots__ = ("unused_imports",)
17+
__slots__ = ("unused_imports", "_nodes_to_remove", "_pending_lines", "_pending_lines_stack")
1818

1919
METADATA_DEPENDENCIES = (PositionProvider,)
2020

2121
def __init__(self, unused_imports: list[Import | ImportFrom]) -> None:
2222
super().__init__()
2323

2424
self.unused_imports = unused_imports
25+
self._nodes_to_remove: set[int] = set()
26+
self._pending_lines: list[cst.EmptyLine] = []
27+
self._pending_lines_stack: list[list[cst.EmptyLine]] = []
2528

2629
@staticmethod
2730
def get_import_name_from_attr(attr_node: cst.Attribute) -> str:
@@ -59,9 +62,7 @@ def get_rpar(rpar: cst.RightParen | None, location: CodeRange) -> cst.RightParen
5962
else:
6063
return cst.RightParen(whitespace_before=cst.ParenthesizedWhitespace())
6164

62-
def leave_import_alike(
63-
self, original_node: T.CSTImportT, updated_node: T.CSTImportT
64-
) -> cst.RemovalSentinel | T.CSTImportT:
65+
def leave_import_alike(self, original_node: T.CSTImportT, updated_node: T.CSTImportT) -> T.CSTImportT:
6566
names_to_keep = []
6667
names = cast(Sequence[cst.ImportAlias], updated_node.names)
6768
# already handled by leave_ImportFrom
@@ -78,7 +79,8 @@ def leave_import_alike(
7879
if self.is_import_used(import_name, column + 1, self.get_location(original_node)):
7980
names_to_keep.append(import_alias)
8081
if not names_to_keep:
81-
return cst.RemoveFromParent()
82+
self._nodes_to_remove.add(id(original_node))
83+
return updated_node
8284
elif len(names) == len(names_to_keep):
8385
return updated_node
8486
else:
@@ -91,19 +93,17 @@ def leave_import_alike(
9193
return cast(T.CSTImportT, updated_node)
9294

9395
@staticmethod
94-
def leave_StarImport(updated_node: cst.ImportFrom, imp: ImportFrom) -> cst.ImportFrom | cst.RemovalSentinel:
96+
def leave_StarImport(updated_node: cst.ImportFrom, imp: ImportFrom) -> tuple[cst.ImportFrom, bool]:
9597
if imp.suggestions:
9698
names_to_suggestions = [cst.ImportAlias(cst.Name(module)) for module in imp.suggestions]
97-
return updated_node.with_changes(names=names_to_suggestions)
99+
return updated_node.with_changes(names=names_to_suggestions), False
98100
else:
99-
return cst.RemoveFromParent()
101+
return updated_node, True
100102

101-
def leave_Import(self, original_node: cst.Import, updated_node: cst.Import) -> cst.RemovalSentinel | cst.Import:
103+
def leave_Import(self, original_node: cst.Import, updated_node: cst.Import) -> cst.Import:
102104
return self.leave_import_alike(original_node, updated_node)
103105

104-
def leave_ImportFrom(
105-
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
106-
) -> cst.RemovalSentinel | cst.ImportFrom:
106+
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
107107
if isinstance(updated_node.names, cst.ImportStar):
108108

109109
def get_star_imp() -> ImportFrom | None:
@@ -120,12 +120,139 @@ def get_star_imp() -> ImportFrom | None:
120120

121121
imp = get_star_imp()
122122
if imp:
123-
return self.leave_StarImport(updated_node, imp)
123+
result, should_remove = self.leave_StarImport(updated_node, imp)
124+
if should_remove:
125+
self._nodes_to_remove.add(id(original_node))
126+
return result
124127
else:
125128
return original_node
126129

127130
return self.leave_import_alike(original_node, updated_node)
128131

132+
def leave_SimpleStatementLine(
133+
self,
134+
original_node: cst.SimpleStatementLine,
135+
updated_node: cst.SimpleStatementLine,
136+
) -> cst.SimpleStatementLine | cst.RemovalSentinel:
137+
# Check if any child import node was marked for removal
138+
should_remove = False
139+
for stmt in original_node.body:
140+
if id(stmt) in self._nodes_to_remove:
141+
should_remove = True
142+
break
143+
144+
if should_remove:
145+
# Extract comment-bearing lines (and blank lines that precede them)
146+
# from leading_lines and stash them
147+
lines = list(updated_node.leading_lines)
148+
preserved: list[cst.EmptyLine] = []
149+
for i, line in enumerate(lines):
150+
if isinstance(line, cst.EmptyLine) and line.comment is not None:
151+
# Also include blank lines immediately before this comment
152+
j = i - 1
153+
blank_prefix: list[cst.EmptyLine] = []
154+
while j >= 0 and isinstance(lines[j], cst.EmptyLine) and lines[j].comment is None:
155+
blank_prefix.append(lines[j])
156+
j -= 1
157+
blank_prefix.reverse()
158+
preserved.extend(blank_prefix)
159+
preserved.append(line)
160+
self._pending_lines.extend(preserved)
161+
return cst.RemoveFromParent()
162+
163+
# If there are pending comment lines, prepend them to this statement
164+
if self._pending_lines:
165+
new_leading = list(self._pending_lines) + list(updated_node.leading_lines)
166+
self._pending_lines.clear()
167+
return updated_node.with_changes(leading_lines=new_leading)
168+
169+
return updated_node
170+
171+
# -- Compound statement scope isolation --
172+
# Push/pop pending lines so that nested statements don't consume
173+
# pending lines from the outer scope.
174+
175+
def _push_pending(self) -> None:
176+
self._pending_lines_stack.append(self._pending_lines)
177+
self._pending_lines = []
178+
179+
def _pop_and_apply(self, updated_node: cst.BaseCompoundStatement) -> cst.BaseCompoundStatement:
180+
self._pending_lines = self._pending_lines_stack.pop()
181+
if self._pending_lines:
182+
new_leading = list(self._pending_lines) + list(updated_node.leading_lines)
183+
self._pending_lines.clear()
184+
return updated_node.with_changes(leading_lines=new_leading)
185+
return updated_node
186+
187+
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
188+
self._push_pending()
189+
return True
190+
191+
def leave_ClassDef(
192+
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
193+
) -> cst.BaseStatement | cst.RemovalSentinel:
194+
return self._pop_and_apply(updated_node)
195+
196+
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
197+
self._push_pending()
198+
return True
199+
200+
def leave_FunctionDef(
201+
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
202+
) -> cst.BaseStatement | cst.RemovalSentinel:
203+
return self._pop_and_apply(updated_node)
204+
205+
def visit_If(self, node: cst.If) -> bool:
206+
self._push_pending()
207+
return True
208+
209+
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.BaseStatement | cst.RemovalSentinel:
210+
return self._pop_and_apply(updated_node)
211+
212+
def visit_For(self, node: cst.For) -> bool:
213+
self._push_pending()
214+
return True
215+
216+
def leave_For(self, original_node: cst.For, updated_node: cst.For) -> cst.BaseStatement | cst.RemovalSentinel:
217+
return self._pop_and_apply(updated_node)
218+
219+
def visit_While(self, node: cst.While) -> bool:
220+
self._push_pending()
221+
return True
222+
223+
def leave_While(self, original_node: cst.While, updated_node: cst.While) -> cst.BaseStatement | cst.RemovalSentinel:
224+
return self._pop_and_apply(updated_node)
225+
226+
def visit_Try(self, node: cst.Try) -> bool:
227+
self._push_pending()
228+
return True
229+
230+
def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.BaseStatement | cst.RemovalSentinel:
231+
return self._pop_and_apply(updated_node)
232+
233+
def visit_TryStar(self, node: cst.TryStar) -> bool:
234+
self._push_pending()
235+
return True
236+
237+
def leave_TryStar(
238+
self, original_node: cst.TryStar, updated_node: cst.TryStar
239+
) -> cst.BaseStatement | cst.RemovalSentinel:
240+
return self._pop_and_apply(updated_node)
241+
242+
def visit_With(self, node: cst.With) -> bool:
243+
self._push_pending()
244+
return True
245+
246+
def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.BaseStatement | cst.RemovalSentinel:
247+
return self._pop_and_apply(updated_node)
248+
249+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
250+
if self._pending_lines:
251+
new_footer = list(updated_node.footer) + list(self._pending_lines)
252+
self._pending_lines.clear()
253+
return updated_node.with_changes(footer=new_footer)
254+
return updated_node
255+
129256

130257
def refactor_string(source: str, unused_imports: list[Import | ImportFrom]) -> str:
131258
if unused_imports:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=9, name="print", is_all=False),
10+
Name(lineno=9, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
Import(lineno=4, column=1, name="sys", package="sys"),
15+
]
16+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
17+
Import(lineno=4, column=1, name="sys", package="sys"),
18+
]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=3, name="print", is_all=False),
10+
Name(lineno=3, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
Import(lineno=6, column=1, name="sys", package="sys"),
15+
]
16+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
17+
Import(lineno=6, column=1, name="sys", package="sys"),
18+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=6, name="print", is_all=False),
10+
Name(lineno=6, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
ImportFrom(
15+
lineno=4,
16+
column=1,
17+
name="compile_command",
18+
package="codeop",
19+
star=False,
20+
suggestions=[],
21+
),
22+
]
23+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
24+
ImportFrom(
25+
lineno=4,
26+
column=1,
27+
name="compile_command",
28+
package="codeop",
29+
star=False,
30+
suggestions=[],
31+
)
32+
]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=11, name="os", is_all=False),
10+
Name(lineno=11, name="Union", is_all=False),
11+
Name(lineno=11, name="driver", is_all=False),
12+
Name(lineno=11, name="Grammar", is_all=False),
13+
]
14+
IMPORTS: list[Union[Import, ImportFrom]] = [
15+
Import(lineno=2, column=1, name="os", package="os"),
16+
ImportFrom(
17+
lineno=4,
18+
column=1,
19+
name="Union",
20+
package="typing",
21+
star=False,
22+
suggestions=[],
23+
),
24+
ImportFrom(
25+
lineno=7,
26+
column=1,
27+
name="token",
28+
package=".pgen2",
29+
star=False,
30+
suggestions=[],
31+
),
32+
ImportFrom(
33+
lineno=8,
34+
column=1,
35+
name="driver",
36+
package=".pgen2",
37+
star=False,
38+
suggestions=[],
39+
),
40+
ImportFrom(
41+
lineno=10,
42+
column=1,
43+
name="Grammar",
44+
package=".pgen2.grammar",
45+
star=False,
46+
suggestions=[],
47+
),
48+
]
49+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
50+
ImportFrom(
51+
lineno=7,
52+
column=1,
53+
name="token",
54+
package=".pgen2",
55+
star=False,
56+
suggestions=[],
57+
)
58+
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=9, name="print", is_all=False),
10+
Name(lineno=9, name="os", is_all=False),
11+
Name(lineno=10, name="print", is_all=False),
12+
Name(lineno=10, name="OrderedDict", is_all=False),
13+
]
14+
IMPORTS: list[Union[Import, ImportFrom]] = [
15+
Import(lineno=1, column=1, name="os", package="os"),
16+
Import(lineno=4, column=1, name="sys", package="sys"),
17+
ImportFrom(
18+
lineno=7,
19+
column=1,
20+
name="OrderedDict",
21+
package="collections",
22+
star=False,
23+
suggestions=[],
24+
),
25+
]
26+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
27+
Import(lineno=4, column=1, name="sys", package="sys"),
28+
]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=9, name="print", is_all=False),
10+
Name(lineno=9, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
Import(lineno=4, column=1, name="sys", package="sys"),
15+
Import(lineno=7, column=1, name="json", package="json"),
16+
]
17+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
18+
Import(lineno=7, column=1, name="json", package="json"),
19+
Import(lineno=4, column=1, name="sys", package="sys"),
20+
]

0 commit comments

Comments
 (0)