diff --git a/docs/tutorial/supported-behaviors.md b/docs/tutorial/supported-behaviors.md index 0be33e4c..9053b7c2 100644 --- a/docs/tutorial/supported-behaviors.md +++ b/docs/tutorial/supported-behaviors.md @@ -55,6 +55,29 @@ For more information --- +#### TYPE_CHECKING + +Unimport recognizes `if TYPE_CHECKING:` blocks and skips imports inside them. These +imports only run during static analysis and are not available at runtime, so they should +never shadow or conflict with runtime imports. + +```python +from qtpy import QtCore +import typing as t + +if t.TYPE_CHECKING: + from PySide6 import QtCore + +class MyThread(QtCore.QThread): + pass +``` + +In this example, unimport correctly keeps `from qtpy import QtCore` (the runtime import) +and ignores the `TYPE_CHECKING`-guarded import. Both `if TYPE_CHECKING:` and +`if typing.TYPE_CHECKING:` (or any alias like `if t.TYPE_CHECKING:`) are supported. + +--- + ## All Unimport looks at the items in the `__all__` list, if it matches the imports, marks it diff --git a/src/unimport/analyzers/import_statement.py b/src/unimport/analyzers/import_statement.py index 077b677e..245f7a84 100644 --- a/src/unimport/analyzers/import_statement.py +++ b/src/unimport/analyzers/import_statement.py @@ -21,6 +21,7 @@ class ImportAnalyzer(ast.NodeVisitor): "any_import_error", "if_names", "orelse_names", + "_in_type_checking", ) IGNORE_MODULES_IMPORTS = ("__future__",) @@ -37,6 +38,7 @@ def __init__( self.if_names: set[str] = set() self.orelse_names: set[str] = set() + self._in_type_checking: bool = False def traverse(self, tree) -> None: self.visit(tree) @@ -58,7 +60,14 @@ def visit_Import(self, node: ast.Import) -> None: if name in self.IGNORE_IMPORT_NAMES or (name in self.if_names and name in self.orelse_names): continue - Import.register(lineno=node.lineno, column=column + 1, name=name, package=alias.name, node=node) + Import.register( + lineno=node.lineno, + column=column + 1, + name=name, + package=alias.name, + node=node, + is_type_checking=self._in_type_checking, + ) @generic_visit @skip_import @@ -82,9 +91,28 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: star=is_star, suggestions=self.get_suggestions(package) if is_star else [], node=node, + is_type_checking=self._in_type_checking, ) + @staticmethod + def _is_type_checking_block(if_node: ast.If) -> bool: + test = if_node.test + if isinstance(test, ast.Name) and test.id == "TYPE_CHECKING": + return True + if isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING": + return True + return False + def visit_If(self, if_node: ast.If) -> None: + if self._is_type_checking_block(if_node): + self._in_type_checking = True + for node in if_node.body: + self.visit(node) + self._in_type_checking = False + for node in if_node.orelse: + self.visit(node) + return + self.if_names = { name.asname or name.name for n in filter(lambda node: isinstance(node, (ast.Import, ast.ImportFrom)), if_node.body) diff --git a/src/unimport/analyzers/main.py b/src/unimport/analyzers/main.py index 5d5eaa3a..91c51090 100644 --- a/src/unimport/analyzers/main.py +++ b/src/unimport/analyzers/main.py @@ -45,6 +45,7 @@ def traverse(self) -> None: ).traverse(tree) self._deduplicate_star_suggestions() + self._cleanup_empty_type_checking() Scope.remove_current_scope() # remove global scope @@ -69,6 +70,30 @@ def _deduplicate_star_suggestions() -> None: else: seen.add(imp.name) + @staticmethod + def _cleanup_empty_type_checking() -> None: + """If all TYPE_CHECKING-guarded imports are unused, mark TYPE_CHECKING import as unused. + + Removes TYPE_CHECKING Name references so the import becomes unused, + enabling removal of both the import and the empty if-block. + """ + tc_imports = [imp for imp in Import.imports if imp.is_type_checking] + if not tc_imports: + return + + if any(imp.is_used() for imp in tc_imports): + return + + names_to_remove = [ + name for name in Name.names if name.name == "TYPE_CHECKING" or name.name.endswith(".TYPE_CHECKING") + ] + for name in names_to_remove: + Name.names.remove(name) + for scope in Scope.scopes: + if name in scope.current_nodes: + scope.current_nodes.remove(name) + break + @staticmethod def clear(): Name.clear() diff --git a/src/unimport/refactor.py b/src/unimport/refactor.py index b8dc7dc2..63cac371 100644 --- a/src/unimport/refactor.py +++ b/src/unimport/refactor.py @@ -207,8 +207,32 @@ def visit_If(self, node: cst.If) -> bool: return True def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.BaseStatement | cst.RemovalSentinel: + if ( + self._is_type_checking_if(updated_node) + and self._is_body_pass_only(updated_node.body) + and updated_node.orelse is None + ): + self._pending_lines = self._pending_lines_stack.pop() + return cst.RemoveFromParent() return self._pop_and_apply(updated_node) + @staticmethod + def _is_type_checking_if(node: cst.If) -> bool: + if isinstance(node.test, cst.Name) and node.test.value == "TYPE_CHECKING": + return True + if isinstance(node.test, cst.Attribute) and node.test.attr.value == "TYPE_CHECKING": + return True + return False + + @staticmethod + def _is_body_pass_only(body: cst.BaseSuite) -> bool: + if not isinstance(body, cst.IndentedBlock): + return False + return all( + isinstance(stmt, cst.SimpleStatementLine) and all(isinstance(s, cst.Pass) for s in stmt.body) + for stmt in body.body + ) + def visit_For(self, node: cst.For) -> bool: self._push_pending() return True diff --git a/src/unimport/statement.py b/src/unimport/statement.py index 29e655e4..5d5973b3 100644 --- a/src/unimport/statement.py +++ b/src/unimport/statement.py @@ -17,6 +17,7 @@ class Import: package: str node: ast.Import | ast.ImportFrom = dataclasses.field(init=False, repr=False, compare=False) + is_type_checking: bool = dataclasses.field(init=False, repr=False, compare=False, default=False) def __len__(self) -> int: return len(self.name.split(".")) @@ -30,7 +31,10 @@ def scope(self): def is_used(self) -> bool: for name in self.scope.names: - if name.match_import: + if self.is_type_checking: + if name.match_2(self): + return True + elif name.match_import: if name.match_import == self: return True elif name.match(self): @@ -43,7 +47,11 @@ def match_nearest_duplicate_import(self, name: Name) -> bool: scope = name.scope while scope: - imports = [_import for _import in scope.imports if name.match_2(_import) and name.lineno > _import.lineno] + imports = [ + _import + for _import in scope.imports + if name.match_2(_import) and name.lineno > _import.lineno and not _import.is_type_checking + ] scope = scope.parent if imports: @@ -62,7 +70,7 @@ def match_nearest_duplicate_import(self, name: Name) -> bool: @property def is_duplicate(self) -> bool: - return [_import.name for _import in self.imports].count(self.name) > 1 + return [_import.name for _import in self.imports if not _import.is_type_checking].count(self.name) > 1 @classmethod def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iterator[Import | ImportFrom]: @@ -73,9 +81,12 @@ def get_unused_imports(cls, *, include_star_import: bool = False) -> typing.Iter yield imp @classmethod - def register(cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import) -> None: + def register( + cls, *, lineno: int, column: int, name: str, package: str, node: ast.Import, is_type_checking: bool = False + ) -> None: _import = cls(lineno, column, name, package) _import.node = node + _import.is_type_checking = is_type_checking cls.imports.append(_import) Scope.register(_import) @@ -104,9 +115,11 @@ def register( # type: ignore[override] # noqa star: bool, suggestions: list[str], node: ast.ImportFrom, + is_type_checking: bool = False, ) -> None: _import = cls(lineno, column, name, package, star, suggestions) _import.node = node + _import.is_type_checking = is_type_checking cls.imports.append(_import) Scope.register(_import) diff --git a/tests/cases/analyzer/type_variable/type_checking_duplicate_import.py b/tests/cases/analyzer/type_variable/type_checking_duplicate_import.py new file mode 100644 index 00000000..448edbff --- /dev/null +++ b/tests/cases/analyzer/type_variable/type_checking_duplicate_import.py @@ -0,0 +1,36 @@ +from typing import Union + +from unimport.statement import Import, ImportFrom, Name + +__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"] + + +NAMES: list[Name] = [ + Name(lineno=4, name="t.TYPE_CHECKING", is_all=False), + Name(lineno=7, name="QtCore.QThread", is_all=False), +] +IMPORTS: list[Union[Import, ImportFrom]] = [ + ImportFrom( + lineno=1, + column=1, + name="QtCore", + package="qtpy", + star=False, + suggestions=[], + ), + Import( + lineno=2, + column=1, + name="t", + package="typing", + ), + ImportFrom( + lineno=5, + column=1, + name="QtCore", + package="PySide6", + star=False, + suggestions=[], + ), +] +UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [] diff --git a/tests/cases/analyzer/type_variable/type_checking_unused_import.py b/tests/cases/analyzer/type_variable/type_checking_unused_import.py new file mode 100644 index 00000000..7f6661c2 --- /dev/null +++ b/tests/cases/analyzer/type_variable/type_checking_unused_import.py @@ -0,0 +1,44 @@ +from typing import Union + +from unimport.statement import Import, ImportFrom, Name + +__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"] + + +NAMES: list[Name] = [] +IMPORTS: list[Union[Import, ImportFrom]] = [ + ImportFrom( + lineno=1, + column=1, + name="TYPE_CHECKING", + package="typing", + star=False, + suggestions=[], + ), + ImportFrom( + lineno=4, + column=1, + name="QtWebKit", + package="PyQt5", + star=False, + suggestions=[], + ), +] +UNUSED_IMPORTS: list[Union[Import, ImportFrom]] = [ + ImportFrom( + lineno=4, + column=1, + name="QtWebKit", + package="PyQt5", + star=False, + suggestions=[], + ), + ImportFrom( + lineno=1, + column=1, + name="TYPE_CHECKING", + package="typing", + star=False, + suggestions=[], + ), +] diff --git a/tests/cases/refactor/type_variable/type_checking_duplicate_import.py b/tests/cases/refactor/type_variable/type_checking_duplicate_import.py new file mode 100644 index 00000000..88b2a165 --- /dev/null +++ b/tests/cases/refactor/type_variable/type_checking_duplicate_import.py @@ -0,0 +1,8 @@ +from qtpy import QtCore +import typing as t + +if t.TYPE_CHECKING: + from PySide6 import QtCore + +class MyThread(QtCore.QThread): + pass diff --git a/tests/cases/refactor/type_variable/type_checking_unused_import.py b/tests/cases/refactor/type_variable/type_checking_unused_import.py new file mode 100644 index 00000000..e6e9b09c --- /dev/null +++ b/tests/cases/refactor/type_variable/type_checking_unused_import.py @@ -0,0 +1,3 @@ + +class Foo: + pass diff --git a/tests/cases/source/type_variable/type_checking_duplicate_import.py b/tests/cases/source/type_variable/type_checking_duplicate_import.py new file mode 100644 index 00000000..88b2a165 --- /dev/null +++ b/tests/cases/source/type_variable/type_checking_duplicate_import.py @@ -0,0 +1,8 @@ +from qtpy import QtCore +import typing as t + +if t.TYPE_CHECKING: + from PySide6 import QtCore + +class MyThread(QtCore.QThread): + pass diff --git a/tests/cases/source/type_variable/type_checking_unused_import.py b/tests/cases/source/type_variable/type_checking_unused_import.py new file mode 100644 index 00000000..3f133430 --- /dev/null +++ b/tests/cases/source/type_variable/type_checking_unused_import.py @@ -0,0 +1,7 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from PyQt5 import QtWebKit + +class Foo: + pass