Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/tutorial/supported-behaviors.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,29 @@ For more information

---

#### TYPE_CHECKING

Unimport recognizes `if TYPE_CHECKING:` blocks and skips imports inside them. These
imports only run during static analysis and are not available at runtime, so they should
never shadow or conflict with runtime imports.

```python
from qtpy import QtCore
import typing as t

if t.TYPE_CHECKING:
from PySide6 import QtCore

class MyThread(QtCore.QThread):
pass
```

In this example, unimport correctly keeps `from qtpy import QtCore` (the runtime import)
and ignores the `TYPE_CHECKING`-guarded import. Both `if TYPE_CHECKING:` and
`if typing.TYPE_CHECKING:` (or any alias like `if t.TYPE_CHECKING:`) are supported.

---

## All

Unimport looks at the items in the `__all__` list, if it matches the imports, marks it
Expand Down
30 changes: 29 additions & 1 deletion src/unimport/analyzers/import_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ImportAnalyzer(ast.NodeVisitor):
"any_import_error",
"if_names",
"orelse_names",
"_in_type_checking",
)

IGNORE_MODULES_IMPORTS = ("__future__",)
Expand All @@ -37,6 +38,7 @@ def __init__(

self.if_names: set[str] = set()
self.orelse_names: set[str] = set()
self._in_type_checking: bool = False

def traverse(self, tree) -> None:
self.visit(tree)
Expand All @@ -58,7 +60,14 @@ def visit_Import(self, node: ast.Import) -> None:
if name in self.IGNORE_IMPORT_NAMES or (name in self.if_names and name in self.orelse_names):
continue

Import.register(lineno=node.lineno, column=column + 1, name=name, package=alias.name, node=node)
Import.register(
lineno=node.lineno,
column=column + 1,
name=name,
package=alias.name,
node=node,
is_type_checking=self._in_type_checking,
)

@generic_visit
@skip_import
Expand All @@ -82,9 +91,28 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
star=is_star,
suggestions=self.get_suggestions(package) if is_star else [],
node=node,
is_type_checking=self._in_type_checking,
)

@staticmethod
def _is_type_checking_block(if_node: ast.If) -> bool:
test = if_node.test
if isinstance(test, ast.Name) and test.id == "TYPE_CHECKING":
return True
if isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING":
return True
return False

def visit_If(self, if_node: ast.If) -> None:
if self._is_type_checking_block(if_node):
self._in_type_checking = True
for node in if_node.body:
self.visit(node)
self._in_type_checking = False
for node in if_node.orelse:
self.visit(node)
return

self.if_names = {
name.asname or name.name
for n in filter(lambda node: isinstance(node, (ast.Import, ast.ImportFrom)), if_node.body)
Expand Down
25 changes: 25 additions & 0 deletions src/unimport/analyzers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def traverse(self) -> None:
).traverse(tree)

self._deduplicate_star_suggestions()
self._cleanup_empty_type_checking()

Scope.remove_current_scope() # remove global scope

Expand All @@ -69,6 +70,30 @@ def _deduplicate_star_suggestions() -> None:
else:
seen.add(imp.name)

@staticmethod
def _cleanup_empty_type_checking() -> None:
"""If all TYPE_CHECKING-guarded imports are unused, mark TYPE_CHECKING import as unused.

Removes TYPE_CHECKING Name references so the import becomes unused,
enabling removal of both the import and the empty if-block.
"""
tc_imports = [imp for imp in Import.imports if imp.is_type_checking]
if not tc_imports:
return

if any(imp.is_used() for imp in tc_imports):
return

names_to_remove = [
name for name in Name.names if name.name == "TYPE_CHECKING" or name.name.endswith(".TYPE_CHECKING")
]
for name in names_to_remove:
Name.names.remove(name)
for scope in Scope.scopes:
if name in scope.current_nodes:
scope.current_nodes.remove(name)
break

@staticmethod
def clear():
Name.clear()
Expand Down
24 changes: 24 additions & 0 deletions src/unimport/refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,32 @@ def visit_If(self, node: cst.If) -> bool:
return True

def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.BaseStatement | cst.RemovalSentinel:
if (
self._is_type_checking_if(updated_node)
and self._is_body_pass_only(updated_node.body)
and updated_node.orelse is None
):
self._pending_lines = self._pending_lines_stack.pop()
return cst.RemoveFromParent()
return self._pop_and_apply(updated_node)

@staticmethod
def _is_type_checking_if(node: cst.If) -> bool:
if isinstance(node.test, cst.Name) and node.test.value == "TYPE_CHECKING":
return True
if isinstance(node.test, cst.Attribute) and node.test.attr.value == "TYPE_CHECKING":
return True
return False

@staticmethod
def _is_body_pass_only(body: cst.BaseSuite) -> bool:
if not isinstance(body, cst.IndentedBlock):
return False
return all(
isinstance(stmt, cst.SimpleStatementLine) and all(isinstance(s, cst.Pass) for s in stmt.body)
for stmt in body.body
)

def visit_For(self, node: cst.For) -> bool:
self._push_pending()
return True
Expand Down
21 changes: 17 additions & 4 deletions src/unimport/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Import:
package: str

node: ast.Import | ast.ImportFrom = dataclasses.field(init=False, repr=False, compare=False)
is_type_checking: bool = dataclasses.field(init=False, repr=False, compare=False, default=False)

def __len__(self) -> int:
return len(self.name.split("."))
Expand All @@ -30,7 +31,10 @@ def scope(self):

def is_used(self) -> bool:
for name in self.scope.names:
if name.match_import:
if self.is_type_checking:
if name.match_2(self):
return True
elif name.match_import:
if name.match_import == self:
return True
elif name.match(self):
Expand All @@ -43,7 +47,11 @@ def match_nearest_duplicate_import(self, name: Name) -> bool:

scope = name.scope
while scope:
imports = [_import for _import in scope.imports if name.match_2(_import) and name.lineno > _import.lineno]
imports = [
_import
for _import in scope.imports
if name.match_2(_import) and name.lineno > _import.lineno and not _import.is_type_checking
]
scope = scope.parent

if imports:
Expand All @@ -62,7 +70,7 @@ def match_nearest_duplicate_import(self, name: Name) -> bool:

@property
def is_duplicate(self) -> bool:
return [_import.name for _import in self.imports].count(self.name) > 1
return [_import.name for _import in self.imports if not _import.is_type_checking].count(self.name) > 1

@classmethod
def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iterator[Import | ImportFrom]:
Expand All @@ -73,9 +81,12 @@ def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iter
yield imp

@classmethod
def register(cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import) -> None:
def register(
cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import, is_type_checking: bool = False
) -> None:
_import = cls(lineno, column, name, package)
_import.node = node
_import.is_type_checking = is_type_checking
cls.imports.append(_import)

Scope.register(_import)
Expand Down Expand Up @@ -104,9 +115,11 @@ def register( # type: ignore[override] # noqa
star: bool,
suggestions: list[str],
node: ast.ImportFrom,
is_type_checking: bool = False,
) -> None:
_import = cls(lineno, column, name, package, star, suggestions)
_import.node = node
_import.is_type_checking = is_type_checking
cls.imports.append(_import)

Scope.register(_import)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Union

from unimport.statement import Import, ImportFrom, Name

__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]


NAMES: list[Name] = [
Name(lineno=4, name="t.TYPE_CHECKING", is_all=False),
Name(lineno=7, name="QtCore.QThread", is_all=False),
]
IMPORTS: list[Union[Import, ImportFrom]] = [
ImportFrom(
lineno=1,
column=1,
name="QtCore",
package="qtpy",
star=False,
suggestions=[],
),
Import(
lineno=2,
column=1,
name="t",
package="typing",
),
ImportFrom(
lineno=5,
column=1,
name="QtCore",
package="PySide6",
star=False,
suggestions=[],
),
]
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = []
44 changes: 44 additions & 0 deletions tests/cases/analyzer/type_variable/type_checking_unused_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Union

from unimport.statement import Import, ImportFrom, Name

__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]


NAMES: list[Name] = []
IMPORTS: list[Union[Import, ImportFrom]] = [
ImportFrom(
lineno=1,
column=1,
name="TYPE_CHECKING",
package="typing",
star=False,
suggestions=[],
),
ImportFrom(
lineno=4,
column=1,
name="QtWebKit",
package="PyQt5",
star=False,
suggestions=[],
),
]
UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [
ImportFrom(
lineno=4,
column=1,
name="QtWebKit",
package="PyQt5",
star=False,
suggestions=[],
),
ImportFrom(
lineno=1,
column=1,
name="TYPE_CHECKING",
package="typing",
star=False,
suggestions=[],
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from qtpy import QtCore
import typing as t

if t.TYPE_CHECKING:
from PySide6 import QtCore

class MyThread(QtCore.QThread):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

class Foo:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from qtpy import QtCore
import typing as t

if t.TYPE_CHECKING:
from PySide6 import QtCore

class MyThread(QtCore.QThread):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from PyQt5 import QtWebKit

class Foo:
pass