diff --git a/mypy/indirection.py b/mypy/indirection.py index 00356d7a4ddb..1be33e45ecba 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterable, Set +from typing import Iterable import mypy.types as types from mypy.types import TypeVisitor @@ -17,105 +17,118 @@ def extract_module_names(type_name: str | None) -> list[str]: return [] -class TypeIndirectionVisitor(TypeVisitor[Set[str]]): +class TypeIndirectionVisitor(TypeVisitor[None]): """Returns all module references within a particular type.""" def __init__(self) -> None: - self.cache: dict[types.Type, set[str]] = {} + # Module references are collected here + self.modules: set[str] = set() + # User to avoid infinite recursion with recursive type aliases self.seen_aliases: set[types.TypeAliasType] = set() + # Used to avoid redundant work + self.seen_fullnames: set[str] = set() def find_modules(self, typs: Iterable[types.Type]) -> set[str]: - self.seen_aliases.clear() - return self._visit(typs) + self.modules = set() + self.seen_fullnames = set() + self.seen_aliases = set() + self._visit(typs) + return self.modules - def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> set[str]: + def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> None: typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs - output: set[str] = set() for typ in typs: if isinstance(typ, types.TypeAliasType): # Avoid infinite recursion for recursive type aliases. if typ in self.seen_aliases: continue self.seen_aliases.add(typ) - if typ in self.cache: - modules = self.cache[typ] - else: - modules = typ.accept(self) - self.cache[typ] = set(modules) - output.update(modules) - return output + typ.accept(self) - def visit_unbound_type(self, t: types.UnboundType) -> set[str]: - return self._visit(t.args) + def _visit_module_name(self, module_name: str) -> None: + if module_name not in self.modules: + self.modules.update(split_module_names(module_name)) - def visit_any(self, t: types.AnyType) -> set[str]: - return set() + def visit_unbound_type(self, t: types.UnboundType) -> None: + self._visit(t.args) - def visit_none_type(self, t: types.NoneType) -> set[str]: - return set() + def visit_any(self, t: types.AnyType) -> None: + pass - def visit_uninhabited_type(self, t: types.UninhabitedType) -> set[str]: - return set() + def visit_none_type(self, t: types.NoneType) -> None: + pass - def visit_erased_type(self, t: types.ErasedType) -> set[str]: - return set() + def visit_uninhabited_type(self, t: types.UninhabitedType) -> None: + pass - def visit_deleted_type(self, t: types.DeletedType) -> set[str]: - return set() + def visit_erased_type(self, t: types.ErasedType) -> None: + pass - def visit_type_var(self, t: types.TypeVarType) -> set[str]: - return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default) + def visit_deleted_type(self, t: types.DeletedType) -> None: + pass - def visit_param_spec(self, t: types.ParamSpecType) -> set[str]: - return self._visit(t.upper_bound) | self._visit(t.default) + def visit_type_var(self, t: types.TypeVarType) -> None: + self._visit(t.values) + self._visit(t.upper_bound) + self._visit(t.default) - def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]: - return self._visit(t.upper_bound) | self._visit(t.default) + def visit_param_spec(self, t: types.ParamSpecType) -> None: + self._visit(t.upper_bound) + self._visit(t.default) - def visit_unpack_type(self, t: types.UnpackType) -> set[str]: - return t.type.accept(self) + def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> None: + self._visit(t.upper_bound) + self._visit(t.default) - def visit_parameters(self, t: types.Parameters) -> set[str]: - return self._visit(t.arg_types) + def visit_unpack_type(self, t: types.UnpackType) -> None: + t.type.accept(self) - def visit_instance(self, t: types.Instance) -> set[str]: - out = self._visit(t.args) + def visit_parameters(self, t: types.Parameters) -> None: + self._visit(t.arg_types) + + def visit_instance(self, t: types.Instance) -> None: + self._visit(t.args) if t.type: # Uses of a class depend on everything in the MRO, # as changes to classes in the MRO can add types to methods, # change property types, change the MRO itself, etc. for s in t.type.mro: - out.update(split_module_names(s.module_name)) + self._visit_module_name(s.module_name) if t.type.metaclass_type is not None: - out.update(split_module_names(t.type.metaclass_type.type.module_name)) - return out + self._visit_module_name(t.type.metaclass_type.type.module_name) - def visit_callable_type(self, t: types.CallableType) -> set[str]: - out = self._visit(t.arg_types) | self._visit(t.ret_type) + def visit_callable_type(self, t: types.CallableType) -> None: + self._visit(t.arg_types) + self._visit(t.ret_type) if t.definition is not None: - out.update(extract_module_names(t.definition.fullname)) - return out + fullname = t.definition.fullname + if fullname not in self.seen_fullnames: + self.modules.update(extract_module_names(t.definition.fullname)) + self.seen_fullnames.add(fullname) - def visit_overloaded(self, t: types.Overloaded) -> set[str]: - return self._visit(t.items) | self._visit(t.fallback) + def visit_overloaded(self, t: types.Overloaded) -> None: + self._visit(t.items) + self._visit(t.fallback) - def visit_tuple_type(self, t: types.TupleType) -> set[str]: - return self._visit(t.items) | self._visit(t.partial_fallback) + def visit_tuple_type(self, t: types.TupleType) -> None: + self._visit(t.items) + self._visit(t.partial_fallback) - def visit_typeddict_type(self, t: types.TypedDictType) -> set[str]: - return self._visit(t.items.values()) | self._visit(t.fallback) + def visit_typeddict_type(self, t: types.TypedDictType) -> None: + self._visit(t.items.values()) + self._visit(t.fallback) - def visit_literal_type(self, t: types.LiteralType) -> set[str]: - return self._visit(t.fallback) + def visit_literal_type(self, t: types.LiteralType) -> None: + self._visit(t.fallback) - def visit_union_type(self, t: types.UnionType) -> set[str]: - return self._visit(t.items) + def visit_union_type(self, t: types.UnionType) -> None: + self._visit(t.items) - def visit_partial_type(self, t: types.PartialType) -> set[str]: - return set() + def visit_partial_type(self, t: types.PartialType) -> None: + pass - def visit_type_type(self, t: types.TypeType) -> set[str]: - return self._visit(t.item) + def visit_type_type(self, t: types.TypeType) -> None: + self._visit(t.item) - def visit_type_alias_type(self, t: types.TypeAliasType) -> set[str]: - return self._visit(types.get_proper_type(t)) + def visit_type_alias_type(self, t: types.TypeAliasType) -> None: + self._visit(types.get_proper_type(t)) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 0380d1aa82d1..35102be80f5d 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -230,12 +230,14 @@ def test_recursive_nested_in_non_recursive(self) -> None: def test_indirection_no_infinite_recursion(self) -> None: A, _ = self.fx.def_alias_1(self.fx.a) visitor = TypeIndirectionVisitor() - modules = A.accept(visitor) + A.accept(visitor) + modules = visitor.modules assert modules == {"__main__", "builtins"} A, _ = self.fx.def_alias_2(self.fx.a) visitor = TypeIndirectionVisitor() - modules = A.accept(visitor) + A.accept(visitor) + modules = visitor.modules assert modules == {"__main__", "builtins"}