Skip to content

Commit d728ffd

Browse files
hakancelikdevclaude
andcommitted
Fix incorrect removal of runtime import shadowed by TYPE_CHECKING import (#313)
Register TYPE_CHECKING-guarded imports with is_type_checking flag instead of skipping them entirely. TC imports use read-only matching (match_2) so they don't interfere with runtime import resolution, and are excluded from duplicate detection. Genuinely unused TC imports are still detected and removed. When all TC-guarded imports are unused, the TYPE_CHECKING import itself and the empty if-block are also removed automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a702ffb commit d728ffd

File tree

11 files changed

+224
-5
lines changed

11 files changed

+224
-5
lines changed

docs/tutorial/supported-behaviors.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ For more information
5555

5656
---
5757

58+
#### TYPE_CHECKING
59+
60+
Unimport recognizes `if TYPE_CHECKING:` blocks and skips imports inside them. These
61+
imports only run during static analysis and are not available at runtime, so they should
62+
never shadow or conflict with runtime imports.
63+
64+
```python
65+
from qtpy import QtCore
66+
import typing as t
67+
68+
if t.TYPE_CHECKING:
69+
from PySide6 import QtCore
70+
71+
class MyThread(QtCore.QThread):
72+
pass
73+
```
74+
75+
In this example, unimport correctly keeps `from qtpy import QtCore` (the runtime import)
76+
and ignores the `TYPE_CHECKING`-guarded import. Both `if TYPE_CHECKING:` and
77+
`if typing.TYPE_CHECKING:` (or any alias like `if t.TYPE_CHECKING:`) are supported.
78+
79+
---
80+
5881
## All
5982

6083
Unimport looks at the items in the `__all__` list, if it matches the imports, marks it

src/unimport/analyzers/import_statement.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class ImportAnalyzer(ast.NodeVisitor):
2121
"any_import_error",
2222
"if_names",
2323
"orelse_names",
24+
"_in_type_checking",
2425
)
2526

2627
IGNORE_MODULES_IMPORTS = ("__future__",)
@@ -37,6 +38,7 @@ def __init__(
3738

3839
self.if_names: set[str] = set()
3940
self.orelse_names: set[str] = set()
41+
self._in_type_checking: bool = False
4042

4143
def traverse(self, tree) -> None:
4244
self.visit(tree)
@@ -58,7 +60,14 @@ def visit_Import(self, node: ast.Import) -> None:
5860
if name in self.IGNORE_IMPORT_NAMES or (name in self.if_names and name in self.orelse_names):
5961
continue
6062

61-
Import.register(lineno=node.lineno, column=column + 1, name=name, package=alias.name, node=node)
63+
Import.register(
64+
lineno=node.lineno,
65+
column=column + 1,
66+
name=name,
67+
package=alias.name,
68+
node=node,
69+
is_type_checking=self._in_type_checking,
70+
)
6271

6372
@generic_visit
6473
@skip_import
@@ -82,9 +91,28 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
8291
star=is_star,
8392
suggestions=self.get_suggestions(package) if is_star else [],
8493
node=node,
94+
is_type_checking=self._in_type_checking,
8595
)
8696

97+
@staticmethod
98+
def _is_type_checking_block(if_node: ast.If) -> bool:
99+
test = if_node.test
100+
if isinstance(test, ast.Name) and test.id == "TYPE_CHECKING":
101+
return True
102+
if isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING":
103+
return True
104+
return False
105+
87106
def visit_If(self, if_node: ast.If) -> None:
107+
if self._is_type_checking_block(if_node):
108+
self._in_type_checking = True
109+
for node in if_node.body:
110+
self.visit(node)
111+
self._in_type_checking = False
112+
for node in if_node.orelse:
113+
self.visit(node)
114+
return
115+
88116
self.if_names = {
89117
name.asname or name.name
90118
for n in filter(lambda node: isinstance(node, (ast.Import, ast.ImportFrom)), if_node.body)

src/unimport/analyzers/main.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def traverse(self) -> None:
4545
).traverse(tree)
4646

4747
self._deduplicate_star_suggestions()
48+
self._cleanup_empty_type_checking()
4849

4950
Scope.remove_current_scope() # remove global scope
5051

@@ -69,6 +70,30 @@ def _deduplicate_star_suggestions() -> None:
6970
else:
7071
seen.add(imp.name)
7172

73+
@staticmethod
74+
def _cleanup_empty_type_checking() -> None:
75+
"""If all TYPE_CHECKING-guarded imports are unused, mark TYPE_CHECKING import as unused.
76+
77+
Removes TYPE_CHECKING Name references so the import becomes unused,
78+
enabling removal of both the import and the empty if-block.
79+
"""
80+
tc_imports = [imp for imp in Import.imports if imp.is_type_checking]
81+
if not tc_imports:
82+
return
83+
84+
if any(imp.is_used() for imp in tc_imports):
85+
return
86+
87+
names_to_remove = [
88+
name for name in Name.names if name.name == "TYPE_CHECKING" or name.name.endswith(".TYPE_CHECKING")
89+
]
90+
for name in names_to_remove:
91+
Name.names.remove(name)
92+
for scope in Scope.scopes:
93+
if name in scope.current_nodes:
94+
scope.current_nodes.remove(name)
95+
break
96+
7297
@staticmethod
7398
def clear():
7499
Name.clear()

src/unimport/refactor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,32 @@ def visit_If(self, node: cst.If) -> bool:
207207
return True
208208

209209
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.BaseStatement | cst.RemovalSentinel:
210+
if (
211+
self._is_type_checking_if(updated_node)
212+
and self._is_body_pass_only(updated_node.body)
213+
and updated_node.orelse is None
214+
):
215+
self._pending_lines = self._pending_lines_stack.pop()
216+
return cst.RemoveFromParent()
210217
return self._pop_and_apply(updated_node)
211218

219+
@staticmethod
220+
def _is_type_checking_if(node: cst.If) -> bool:
221+
if isinstance(node.test, cst.Name) and node.test.value == "TYPE_CHECKING":
222+
return True
223+
if isinstance(node.test, cst.Attribute) and node.test.attr.value == "TYPE_CHECKING":
224+
return True
225+
return False
226+
227+
@staticmethod
228+
def _is_body_pass_only(body: cst.BaseSuite) -> bool:
229+
if not isinstance(body, cst.IndentedBlock):
230+
return False
231+
return all(
232+
isinstance(stmt, cst.SimpleStatementLine) and all(isinstance(s, cst.Pass) for s in stmt.body)
233+
for stmt in body.body
234+
)
235+
212236
def visit_For(self, node: cst.For) -> bool:
213237
self._push_pending()
214238
return True

src/unimport/statement.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Import:
1717
package: str
1818

1919
node: ast.Import | ast.ImportFrom = dataclasses.field(init=False, repr=False, compare=False)
20+
is_type_checking: bool = dataclasses.field(init=False, repr=False, compare=False, default=False)
2021

2122
def __len__(self) -> int:
2223
return len(self.name.split("."))
@@ -30,7 +31,10 @@ def scope(self):
3031

3132
def is_used(self) -> bool:
3233
for name in self.scope.names:
33-
if name.match_import:
34+
if self.is_type_checking:
35+
if name.match_2(self):
36+
return True
37+
elif name.match_import:
3438
if name.match_import == self:
3539
return True
3640
elif name.match(self):
@@ -43,7 +47,11 @@ def match_nearest_duplicate_import(self, name: Name) -> bool:
4347

4448
scope = name.scope
4549
while scope:
46-
imports = [_import for _import in scope.imports if name.match_2(_import) and name.lineno > _import.lineno]
50+
imports = [
51+
_import
52+
for _import in scope.imports
53+
if name.match_2(_import) and name.lineno > _import.lineno and not _import.is_type_checking
54+
]
4755
scope = scope.parent
4856

4957
if imports:
@@ -62,7 +70,7 @@ def match_nearest_duplicate_import(self, name: Name) -> bool:
6270

6371
@property
6472
def is_duplicate(self) -> bool:
65-
return [_import.name for _import in self.imports].count(self.name) > 1
73+
return [_import.name for _import in self.imports if not _import.is_type_checking].count(self.name) > 1
6674

6775
@classmethod
6876
def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iterator[Import | ImportFrom]:
@@ -73,9 +81,12 @@ def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iter
7381
yield imp
7482

7583
@classmethod
76-
def register(cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import) -> None:
84+
def register(
85+
cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import, is_type_checking: bool = False
86+
) -> None:
7787
_import = cls(lineno, column, name, package)
7888
_import.node = node
89+
_import.is_type_checking = is_type_checking
7990
cls.imports.append(_import)
8091

8192
Scope.register(_import)
@@ -104,9 +115,11 @@ def register( # type: ignore[override] # noqa
104115
star: bool,
105116
suggestions: list[str],
106117
node: ast.ImportFrom,
118+
is_type_checking: bool = False,
107119
) -> None:
108120
_import = cls(lineno, column, name, package, star, suggestions)
109121
_import.node = node
122+
_import.is_type_checking = is_type_checking
110123
cls.imports.append(_import)
111124

112125
Scope.register(_import)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = [
9+
Name(lineno=4, name="t.TYPE_CHECKING", is_all=False),
10+
Name(lineno=7, name="QtCore.QThread", is_all=False),
11+
]
12+
IMPORTS: list[Union[Import, ImportFrom]] = [
13+
ImportFrom(
14+
lineno=1,
15+
column=1,
16+
name="QtCore",
17+
package="qtpy",
18+
star=False,
19+
suggestions=[],
20+
),
21+
Import(
22+
lineno=2,
23+
column=1,
24+
name="t",
25+
package="typing",
26+
),
27+
ImportFrom(
28+
lineno=5,
29+
column=1,
30+
name="QtCore",
31+
package="PySide6",
32+
star=False,
33+
suggestions=[],
34+
),
35+
]
36+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = []
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Union
2+
3+
from unimport.statement import Import, ImportFrom, Name
4+
5+
__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
6+
7+
8+
NAMES: list[Name] = []
9+
IMPORTS: list[Union[Import, ImportFrom]] = [
10+
ImportFrom(
11+
lineno=1,
12+
column=1,
13+
name="TYPE_CHECKING",
14+
package="typing",
15+
star=False,
16+
suggestions=[],
17+
),
18+
ImportFrom(
19+
lineno=4,
20+
column=1,
21+
name="QtWebKit",
22+
package="PyQt5",
23+
star=False,
24+
suggestions=[],
25+
),
26+
]
27+
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
28+
ImportFrom(
29+
lineno=4,
30+
column=1,
31+
name="QtWebKit",
32+
package="PyQt5",
33+
star=False,
34+
suggestions=[],
35+
),
36+
ImportFrom(
37+
lineno=1,
38+
column=1,
39+
name="TYPE_CHECKING",
40+
package="typing",
41+
star=False,
42+
suggestions=[],
43+
),
44+
]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from qtpy import QtCore
2+
import typing as t
3+
4+
if t.TYPE_CHECKING:
5+
from PySide6 import QtCore
6+
7+
class MyThread(QtCore.QThread):
8+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
class Foo:
3+
pass
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from qtpy import QtCore
2+
import typing as t
3+
4+
if t.TYPE_CHECKING:
5+
from PySide6 import QtCore
6+
7+
class MyThread(QtCore.QThread):
8+
pass

0 commit comments

Comments
 (0)