Skip to content

Commit e11b145

Browse files
hakancelikdevclaude
andcommitted
fix: preserve comments and blank lines when removing unused imports
When removing an unused import, libcst's RemoveFromParent() also removed comment lines and blank lines above the import. This change uses a two-phase approach: mark imports for removal, then in leave_SimpleStatementLine extract comment-bearing leading lines and transfer them to the next sibling statement. A stack-based push/pop mechanism isolates pending lines across compound statement boundaries (class, function, if, for, etc.) so nested statements don't consume comments from the outer scope. Closes #100 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6945e95 commit e11b145

22 files changed

+453
-12
lines changed

src/unimport/refactor.py

Lines changed: 151 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4-
from typing import cast
4+
from typing import Union, cast
55

66
import libcst as cst
77
import libcst.matchers as m
@@ -14,14 +14,17 @@
1414

1515

1616
class _RemoveUnusedImportTransformer(cst.CSTTransformer):
17-
__slots__ = ("unused_imports",)
17+
__slots__ = ("unused_imports", "_nodes_to_remove", "_pending_lines", "_pending_lines_stack")
1818

1919
METADATA_DEPENDENCIES = (PositionProvider,)
2020

2121
def __init__(self, unused_imports: list[Import | ImportFrom]) -> None:
2222
super().__init__()
2323

2424
self.unused_imports = unused_imports
25+
self._nodes_to_remove: set[int] = set()
26+
self._pending_lines: list[cst.EmptyLine] = []
27+
self._pending_lines_stack: list[list[cst.EmptyLine]] = []
2528

2629
@staticmethod
2730
def get_import_name_from_attr(attr_node: cst.Attribute) -> str:
@@ -61,7 +64,7 @@ def get_rpar(rpar: cst.RightParen | None, location: CodeRange) -> cst.RightParen
6164

6265
def leave_import_alike(
6366
self, original_node: T.CSTImportT, updated_node: T.CSTImportT
64-
) -> cst.RemovalSentinel | T.CSTImportT:
67+
) -> T.CSTImportT:
6568
names_to_keep = []
6669
names = cast(Sequence[cst.ImportAlias], updated_node.names)
6770
# already handled by leave_ImportFrom
@@ -78,7 +81,8 @@ def leave_import_alike(
7881
if self.is_import_used(import_name, column + 1, self.get_location(original_node)):
7982
names_to_keep.append(import_alias)
8083
if not names_to_keep:
81-
return cst.RemoveFromParent()
84+
self._nodes_to_remove.add(id(original_node))
85+
return updated_node
8286
elif len(names) == len(names_to_keep):
8387
return updated_node
8488
else:
@@ -91,19 +95,17 @@ def leave_import_alike(
9195
return cast(T.CSTImportT, updated_node)
9296

9397
@staticmethod
94-
def leave_StarImport(updated_node: cst.ImportFrom, imp: ImportFrom) -> cst.ImportFrom | cst.RemovalSentinel:
98+
def leave_StarImport(updated_node: cst.ImportFrom, imp: ImportFrom) -> tuple[cst.ImportFrom, bool]:
9599
if imp.suggestions:
96100
names_to_suggestions = [cst.ImportAlias(cst.Name(module)) for module in imp.suggestions]
97-
return updated_node.with_changes(names=names_to_suggestions)
101+
return updated_node.with_changes(names=names_to_suggestions), False
98102
else:
99-
return cst.RemoveFromParent()
103+
return updated_node, True
100104

101-
def leave_Import(self, original_node: cst.Import, updated_node: cst.Import) -> cst.RemovalSentinel | cst.Import:
105+
def leave_Import(self, original_node: cst.Import, updated_node: cst.Import) -> cst.Import:
102106
return self.leave_import_alike(original_node, updated_node)
103107

104-
def leave_ImportFrom(
105-
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
106-
) -> cst.RemovalSentinel | cst.ImportFrom:
108+
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
107109
if isinstance(updated_node.names, cst.ImportStar):
108110

109111
def get_star_imp() -> ImportFrom | None:
@@ -120,12 +122,149 @@ def get_star_imp() -> ImportFrom | None:
120122

121123
imp = get_star_imp()
122124
if imp:
123-
return self.leave_StarImport(updated_node, imp)
125+
result, should_remove = self.leave_StarImport(updated_node, imp)
126+
if should_remove:
127+
self._nodes_to_remove.add(id(original_node))
128+
return result
124129
else:
125130
return original_node
126131

127132
return self.leave_import_alike(original_node, updated_node)
128133

134+
def leave_SimpleStatementLine(
135+
self,
136+
original_node: cst.SimpleStatementLine,
137+
updated_node: cst.SimpleStatementLine,
138+
) -> Union[cst.SimpleStatementLine, cst.RemovalSentinel]:
139+
# Check if any child import node was marked for removal
140+
should_remove = False
141+
for stmt in original_node.body:
142+
if id(stmt) in self._nodes_to_remove:
143+
should_remove = True
144+
break
145+
146+
if should_remove:
147+
# Extract comment-bearing lines (and blank lines that precede them)
148+
# from leading_lines and stash them
149+
lines = list(updated_node.leading_lines)
150+
preserved: list[cst.EmptyLine] = []
151+
for i, line in enumerate(lines):
152+
if isinstance(line, cst.EmptyLine) and line.comment is not None:
153+
# Also include blank lines immediately before this comment
154+
j = i - 1
155+
blank_prefix: list[cst.EmptyLine] = []
156+
while j >= 0 and isinstance(lines[j], cst.EmptyLine) and lines[j].comment is None:
157+
blank_prefix.append(lines[j])
158+
j -= 1
159+
blank_prefix.reverse()
160+
preserved.extend(blank_prefix)
161+
preserved.append(line)
162+
self._pending_lines.extend(preserved)
163+
return cst.RemoveFromParent()
164+
165+
# If there are pending comment lines, prepend them to this statement
166+
if self._pending_lines:
167+
new_leading = list(self._pending_lines) + list(updated_node.leading_lines)
168+
self._pending_lines.clear()
169+
return updated_node.with_changes(leading_lines=new_leading)
170+
171+
return updated_node
172+
173+
# -- Compound statement scope isolation --
174+
# Push/pop pending lines so that nested statements don't consume
175+
# pending lines from the outer scope.
176+
177+
def _push_pending(self) -> None:
178+
self._pending_lines_stack.append(self._pending_lines)
179+
self._pending_lines = []
180+
181+
def _pop_and_apply(self, updated_node: cst.BaseCompoundStatement) -> cst.BaseCompoundStatement:
182+
self._pending_lines = self._pending_lines_stack.pop()
183+
if self._pending_lines:
184+
new_leading = list(self._pending_lines) + list(updated_node.leading_lines)
185+
self._pending_lines.clear()
186+
return updated_node.with_changes(leading_lines=new_leading)
187+
return updated_node
188+
189+
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
190+
self._push_pending()
191+
return True
192+
193+
def leave_ClassDef(
194+
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
195+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
196+
return self._pop_and_apply(updated_node)
197+
198+
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
199+
self._push_pending()
200+
return True
201+
202+
def leave_FunctionDef(
203+
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
204+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
205+
return self._pop_and_apply(updated_node)
206+
207+
def visit_If(self, node: cst.If) -> bool:
208+
self._push_pending()
209+
return True
210+
211+
def leave_If(
212+
self, original_node: cst.If, updated_node: cst.If
213+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
214+
return self._pop_and_apply(updated_node)
215+
216+
def visit_For(self, node: cst.For) -> bool:
217+
self._push_pending()
218+
return True
219+
220+
def leave_For(
221+
self, original_node: cst.For, updated_node: cst.For
222+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
223+
return self._pop_and_apply(updated_node)
224+
225+
def visit_While(self, node: cst.While) -> bool:
226+
self._push_pending()
227+
return True
228+
229+
def leave_While(
230+
self, original_node: cst.While, updated_node: cst.While
231+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
232+
return self._pop_and_apply(updated_node)
233+
234+
def visit_Try(self, node: cst.Try) -> bool:
235+
self._push_pending()
236+
return True
237+
238+
def leave_Try(
239+
self, original_node: cst.Try, updated_node: cst.Try
240+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
241+
return self._pop_and_apply(updated_node)
242+
243+
def visit_TryStar(self, node: cst.TryStar) -> bool:
244+
self._push_pending()
245+
return True
246+
247+
def leave_TryStar(
248+
self, original_node: cst.TryStar, updated_node: cst.TryStar
249+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
250+
return self._pop_and_apply(updated_node)
251+
252+
def visit_With(self, node: cst.With) -> bool:
253+
self._push_pending()
254+
return True
255+
256+
def leave_With(
257+
self, original_node: cst.With, updated_node: cst.With
258+
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
259+
return self._pop_and_apply(updated_node)
260+
261+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
262+
if self._pending_lines:
263+
new_footer = list(updated_node.footer) + list(self._pending_lines)
264+
self._pending_lines.clear()
265+
return updated_node.with_changes(footer=new_footer)
266+
return updated_node
267+
129268

130269
def refactor_string(source: str, unused_imports: list[Import | ImportFrom]) -> str:
131270
if unused_imports:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=9, name="print", is_all=False),
10+
Name(lineno=9, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
Import(lineno=4, column=1, name="sys", package="sys"),
15+
]
16+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
17+
Import(lineno=4, column=1, name="sys", package="sys"),
18+
]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=3, name="print", is_all=False),
10+
Name(lineno=3, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
Import(lineno=6, column=1, name="sys", package="sys"),
15+
]
16+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
17+
Import(lineno=6, column=1, name="sys", package="sys"),
18+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=6, name="print", is_all=False),
10+
Name(lineno=6, name="os", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
Import(lineno=1, column=1, name="os", package="os"),
14+
ImportFrom(
15+
lineno=4,
16+
column=1,
17+
name="compile_command",
18+
package="codeop",
19+
star=False,
20+
suggestions=[],
21+
),
22+
]
23+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
24+
ImportFrom(
25+
lineno=4,
26+
column=1,
27+
name="compile_command",
28+
package="codeop",
29+
star=False,
30+
suggestions=[],
31+
)
32+
]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=11, name="os", is_all=False),
10+
Name(lineno=11, name="Union", is_all=False),
11+
Name(lineno=11, name="driver", is_all=False),
12+
Name(lineno=11, name="Grammar", is_all=False),
13+
]
14+
IMPORTS: list[Union[Import, ImportFrom]] = [
15+
Import(lineno=2, column=1, name="os", package="os"),
16+
ImportFrom(
17+
lineno=4,
18+
column=1,
19+
name="Union",
20+
package="typing",
21+
star=False,
22+
suggestions=[],
23+
),
24+
ImportFrom(
25+
lineno=7,
26+
column=1,
27+
name="token",
28+
package=".pgen2",
29+
star=False,
30+
suggestions=[],
31+
),
32+
ImportFrom(
33+
lineno=8,
34+
column=1,
35+
name="driver",
36+
package=".pgen2",
37+
star=False,
38+
suggestions=[],
39+
),
40+
ImportFrom(
41+
lineno=10,
42+
column=1,
43+
name="Grammar",
44+
package=".pgen2.grammar",
45+
star=False,
46+
suggestions=[],
47+
),
48+
]
49+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
50+
ImportFrom(
51+
lineno=7,
52+
column=1,
53+
name="token",
54+
package=".pgen2",
55+
star=False,
56+
suggestions=[],
57+
)
58+
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=9, name="print", is_all=False),
10+
Name(lineno=9, name="os", is_all=False),
11+
Name(lineno=10, name="print", is_all=False),
12+
Name(lineno=10, name="OrderedDict", is_all=False),
13+
]
14+
IMPORTS: list[Union[Import, ImportFrom]] = [
15+
Import(lineno=1, column=1, name="os", package="os"),
16+
Import(lineno=4, column=1, name="sys", package="sys"),
17+
ImportFrom(
18+
lineno=7,
19+
column=1,
20+
name="OrderedDict",
21+
package="collections",
22+
star=False,
23+
suggestions=[],
24+
),
25+
]
26+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
27+
Import(lineno=4, column=1, name="sys", package="sys"),
28+
]

0 commit comments

Comments
 (0)