diff --git a/mypy/checker.py b/mypy/checker.py index 3b48f66fc3b5..e4b52cfdc6ba 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -12,6 +12,7 @@ import mypy.checkexpr from mypy import errorcodes as codes, join, message_registry, nodes, operators from mypy.binder import ConditionalTypeBinder, Frame, get_declaration +from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange from mypy.checkmember import ( MemberContext, analyze_class_attribute_access, @@ -126,7 +127,7 @@ from mypy.operators import flip_ops, int_op_to_method, neg_ops from mypy.options import PRECISE_TUPLE_TYPES, Options from mypy.patterns import AsPattern, StarredPattern -from mypy.plugin import CheckerPluginInterface, Plugin +from mypy.plugin import Plugin from mypy.plugins import dataclasses as dataclasses_plugin from mypy.scope import Scope from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name @@ -259,13 +260,6 @@ class FineGrainedDeferredNode(NamedTuple): TypeMap: _TypeAlias = Optional[dict[Expression, Type]] -# An object that represents either a precise type or a type with an upper bound; -# it is important for correct type inference with isinstance. -class TypeRange(NamedTuple): - item: Type - is_upper_bound: bool # False => precise type - - # Keeps track of partial types in a single scope. In fine-grained incremental # mode partial types initially defined at the top level cannot be completed in # a function, and we use the 'is_function' attribute to enforce this. @@ -275,7 +269,7 @@ class PartialTypeScope(NamedTuple): is_local: bool -class TypeChecker(NodeVisitor[None], CheckerPluginInterface): +class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): """Mypy type checker. Type check mypy source files that have been semantically analyzed. @@ -302,7 +296,7 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # Helper for managing conditional types binder: ConditionalTypeBinder # Helper for type checking expressions - expr_checker: mypy.checkexpr.ExpressionChecker + _expr_checker: mypy.checkexpr.ExpressionChecker pattern_checker: PatternChecker @@ -417,14 +411,18 @@ def __init__( self.allow_abstract_call = False # Child checker objects for specific AST node types - self.expr_checker = mypy.checkexpr.ExpressionChecker( + self._expr_checker = mypy.checkexpr.ExpressionChecker( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + @property + def expr_checker(self) -> mypy.checkexpr.ExpressionChecker: + return self._expr_checker + @property def type_context(self) -> list[Type | None]: - return self.expr_checker.type_context + return self._expr_checker.type_context def reset(self) -> None: """Cleanup stale state that might be left over from a typechecking run. @@ -8572,75 +8570,6 @@ def is_node_static(node: Node | None) -> bool | None: return None -class CheckerScope: - # We keep two stacks combined, to maintain the relative order - stack: list[TypeInfo | FuncItem | MypyFile] - - def __init__(self, module: MypyFile) -> None: - self.stack = [module] - - def current_function(self) -> FuncItem | None: - for e in reversed(self.stack): - if isinstance(e, FuncItem): - return e - return None - - def top_level_function(self) -> FuncItem | None: - """Return top-level non-lambda function.""" - for e in self.stack: - if isinstance(e, FuncItem) and not isinstance(e, LambdaExpr): - return e - return None - - def active_class(self) -> TypeInfo | None: - if isinstance(self.stack[-1], TypeInfo): - return self.stack[-1] - return None - - def enclosing_class(self, func: FuncItem | None = None) -> TypeInfo | None: - """Is there a class *directly* enclosing this function?""" - func = func or self.current_function() - assert func, "This method must be called from inside a function" - index = self.stack.index(func) - assert index, "CheckerScope stack must always start with a module" - enclosing = self.stack[index - 1] - if isinstance(enclosing, TypeInfo): - return enclosing - return None - - def active_self_type(self) -> Instance | TupleType | None: - """An instance or tuple type representing the current class. - - This returns None unless we are in class body or in a method. - In particular, inside a function nested in method this returns None. - """ - info = self.active_class() - if not info and self.current_function(): - info = self.enclosing_class() - if info: - return fill_typevars(info) - return None - - def current_self_type(self) -> Instance | TupleType | None: - """Same as active_self_type() but handle functions nested in methods.""" - for item in reversed(self.stack): - if isinstance(item, TypeInfo): - return fill_typevars(item) - return None - - @contextmanager - def push_function(self, item: FuncItem) -> Iterator[None]: - self.stack.append(item) - yield - self.stack.pop() - - @contextmanager - def push_class(self, info: TypeInfo) -> Iterator[None]: - self.stack.append(info) - yield - self.stack.pop() - - TKey = TypeVar("TKey") TValue = TypeVar("TValue") diff --git a/mypy/checker_shared.py b/mypy/checker_shared.py new file mode 100644 index 000000000000..6c62af50466c --- /dev/null +++ b/mypy/checker_shared.py @@ -0,0 +1,349 @@ +"""Shared definitions used by different parts of type checker.""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from typing import NamedTuple, overload + +from mypy_extensions import trait + +from mypy.errorcodes import ErrorCode +from mypy.errors import ErrorWatcher +from mypy.message_registry import ErrorMessage +from mypy.nodes import ( + ArgKind, + Context, + Expression, + FuncItem, + LambdaExpr, + MypyFile, + Node, + RefExpr, + TypeAlias, + TypeInfo, + Var, +) +from mypy.plugin import CheckerPluginInterface, Plugin +from mypy.types import ( + CallableType, + Instance, + LiteralValue, + Overloaded, + PartialType, + TupleType, + Type, + TypedDictType, + TypeType, +) +from mypy.typevars import fill_typevars + + +# An object that represents either a precise type or a type with an upper bound; +# it is important for correct type inference with isinstance. +class TypeRange(NamedTuple): + item: Type + is_upper_bound: bool # False => precise type + + +@trait +class ExpressionCheckerSharedApi: + @abstractmethod + def accept( + self, + node: Expression, + type_context: Type | None = None, + allow_none_return: bool = False, + always_allow_any: bool = False, + is_callee: bool = False, + ) -> Type: + raise NotImplementedError + + @abstractmethod + def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type: + raise NotImplementedError + + @abstractmethod + def module_type(self, node: MypyFile) -> Instance: + raise NotImplementedError + + @abstractmethod + def check_call( + self, + callee: Type, + args: list[Expression], + arg_kinds: list[ArgKind], + context: Context, + arg_names: Sequence[str | None] | None = None, + callable_node: Expression | None = None, + callable_name: str | None = None, + object_type: Type | None = None, + original_type: Type | None = None, + ) -> tuple[Type, Type]: + raise NotImplementedError + + @abstractmethod + def transform_callee_type( + self, + callable_name: str | None, + callee: Type, + args: list[Expression], + arg_kinds: list[ArgKind], + context: Context, + arg_names: Sequence[str | None] | None = None, + object_type: Type | None = None, + ) -> Type: + raise NotImplementedError + + @abstractmethod + def method_fullname(self, object_type: Type, method_name: str) -> str | None: + raise NotImplementedError + + @abstractmethod + def check_method_call_by_name( + self, + method: str, + base_type: Type, + args: list[Expression], + arg_kinds: list[ArgKind], + context: Context, + original_type: Type | None = None, + ) -> tuple[Type, Type]: + raise NotImplementedError + + @abstractmethod + def alias_type_in_runtime_context( + self, alias: TypeAlias, *, ctx: Context, alias_definition: bool = False + ) -> Type: + raise NotImplementedError + + @abstractmethod + def visit_typeddict_index_expr( + self, td_type: TypedDictType, index: Expression, setitem: bool = False + ) -> tuple[Type, set[str]]: + raise NotImplementedError + + @abstractmethod + def typeddict_callable(self, info: TypeInfo) -> CallableType: + raise NotImplementedError + + @abstractmethod + def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type: + raise NotImplementedError + + +@trait +class TypeCheckerSharedApi(CheckerPluginInterface): + plugin: Plugin + module_refs: set[str] + scope: CheckerScope + checking_missing_await: bool + + @property + @abstractmethod + def expr_checker(self) -> ExpressionCheckerSharedApi: + raise NotImplementedError + + @abstractmethod + def named_type(self, name: str) -> Instance: + raise NotImplementedError + + @abstractmethod + def lookup_typeinfo(self, fullname: str) -> TypeInfo: + raise NotImplementedError + + @abstractmethod + def lookup_type(self, node: Expression) -> Type: + raise NotImplementedError + + @abstractmethod + def handle_cannot_determine_type(self, name: str, context: Context) -> None: + raise NotImplementedError + + @abstractmethod + def handle_partial_var_type( + self, typ: PartialType, is_lvalue: bool, node: Var, context: Context + ) -> Type: + raise NotImplementedError + + @overload + @abstractmethod + def check_subtype( + self, + subtype: Type, + supertype: Type, + context: Context, + msg: str, + subtype_label: str | None = None, + supertype_label: str | None = None, + *, + notes: list[str] | None = None, + code: ErrorCode | None = None, + outer_context: Context | None = None, + ) -> bool: ... + + @overload + @abstractmethod + def check_subtype( + self, + subtype: Type, + supertype: Type, + context: Context, + msg: ErrorMessage, + subtype_label: str | None = None, + supertype_label: str | None = None, + *, + notes: list[str] | None = None, + outer_context: Context | None = None, + ) -> bool: ... + + # Unfortunately, mypyc doesn't support abstract overloads yet. + @abstractmethod + def check_subtype( + self, + subtype: Type, + supertype: Type, + context: Context, + msg: str | ErrorMessage, + subtype_label: str | None = None, + supertype_label: str | None = None, + *, + notes: list[str] | None = None, + code: ErrorCode | None = None, + outer_context: Context | None = None, + ) -> bool: + raise NotImplementedError + + @abstractmethod + def get_final_context(self) -> bool: + raise NotImplementedError + + @overload + @abstractmethod + def conditional_types_with_intersection( + self, + expr_type: Type, + type_ranges: list[TypeRange] | None, + ctx: Context, + default: None = None, + ) -> tuple[Type | None, Type | None]: ... + + @overload + @abstractmethod + def conditional_types_with_intersection( + self, expr_type: Type, type_ranges: list[TypeRange] | None, ctx: Context, default: Type + ) -> tuple[Type, Type]: ... + + # Unfortunately, mypyc doesn't support abstract overloads yet. + @abstractmethod + def conditional_types_with_intersection( + self, + expr_type: Type, + type_ranges: list[TypeRange] | None, + ctx: Context, + default: Type | None = None, + ) -> tuple[Type | None, Type | None]: + raise NotImplementedError + + @abstractmethod + def check_deprecated(self, node: Node | None, context: Context) -> None: + raise NotImplementedError + + @abstractmethod + def warn_deprecated(self, node: Node | None, context: Context) -> None: + raise NotImplementedError + + @abstractmethod + def warn_deprecated_overload_item( + self, node: Node | None, context: Context, *, target: Type, selftype: Type | None = None + ) -> None: + raise NotImplementedError + + @abstractmethod + def type_is_iterable(self, type: Type) -> bool: + raise NotImplementedError + + @abstractmethod + def iterable_item_type( + self, it: Instance | CallableType | TypeType | Overloaded, context: Context + ) -> Type: + raise NotImplementedError + + @abstractmethod + @contextmanager + def checking_await_set(self) -> Iterator[None]: + raise NotImplementedError + + @abstractmethod + def get_precise_awaitable_type(self, typ: Type, local_errors: ErrorWatcher) -> Type | None: + raise NotImplementedError + + +class CheckerScope: + # We keep two stacks combined, to maintain the relative order + stack: list[TypeInfo | FuncItem | MypyFile] + + def __init__(self, module: MypyFile) -> None: + self.stack = [module] + + def current_function(self) -> FuncItem | None: + for e in reversed(self.stack): + if isinstance(e, FuncItem): + return e + return None + + def top_level_function(self) -> FuncItem | None: + """Return top-level non-lambda function.""" + for e in self.stack: + if isinstance(e, FuncItem) and not isinstance(e, LambdaExpr): + return e + return None + + def active_class(self) -> TypeInfo | None: + if isinstance(self.stack[-1], TypeInfo): + return self.stack[-1] + return None + + def enclosing_class(self, func: FuncItem | None = None) -> TypeInfo | None: + """Is there a class *directly* enclosing this function?""" + func = func or self.current_function() + assert func, "This method must be called from inside a function" + index = self.stack.index(func) + assert index, "CheckerScope stack must always start with a module" + enclosing = self.stack[index - 1] + if isinstance(enclosing, TypeInfo): + return enclosing + return None + + def active_self_type(self) -> Instance | TupleType | None: + """An instance or tuple type representing the current class. + + This returns None unless we are in class body or in a method. + In particular, inside a function nested in method this returns None. + """ + info = self.active_class() + if not info and self.current_function(): + info = self.enclosing_class() + if info: + return fill_typevars(info) + return None + + def current_self_type(self) -> Instance | TupleType | None: + """Same as active_self_type() but handle functions nested in methods.""" + for item in reversed(self.stack): + if isinstance(item, TypeInfo): + return fill_typevars(item) + return None + + @contextmanager + def push_function(self, item: FuncItem) -> Iterator[None]: + self.stack.append(item) + yield + self.stack.pop() + + @contextmanager + def push_class(self, info: TypeInfo) -> Iterator[None]: + self.stack.append(info) + yield + self.stack.pop() diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 12480cf9ab93..099e151dd33d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -15,6 +15,7 @@ import mypy.errorcodes as codes from mypy import applytype, erasetype, join, message_registry, nodes, operators, types from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals +from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access from mypy.checkstrformat import StringFormatterChecker from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars @@ -296,7 +297,7 @@ class UseReverse(enum.Enum): USE_REVERSE_NEVER: Final = UseReverse.NEVER -class ExpressionChecker(ExpressionVisitor[Type]): +class ExpressionChecker(ExpressionVisitor[Type], ExpressionCheckerSharedApi): """Expression type checker. This class works closely together with checker.TypeChecker. @@ -338,7 +339,7 @@ def __init__( # TODO: refactor this to use a pattern similar to one in # multiassign_from_union, or maybe even combine the two? self.type_overrides: dict[Expression, Type] = {} - self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) + self.strfrm_checker = StringFormatterChecker(self.chk, self.msg) self.resolved_type = {} diff --git a/mypy/checkmember.py b/mypy/checkmember.py index dfb141aa415c..2152e309b1df 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -3,9 +3,10 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable, cast +from typing import Callable, cast -from mypy import message_registry, subtypes +from mypy import message_registry, state, subtypes +from mypy.checker_shared import TypeCheckerSharedApi from mypy.erasetype import erase_typevars from mypy.expandtype import ( expand_self_type, @@ -73,11 +74,6 @@ get_proper_type, ) -if TYPE_CHECKING: # import for forward declaration only - import mypy.checker - -from mypy import state - class MemberContext: """Information and objects needed to type check attribute access. @@ -93,7 +89,7 @@ def __init__( is_operator: bool, original_type: Type, context: Context, - chk: mypy.checker.TypeChecker, + chk: TypeCheckerSharedApi, self_type: Type | None = None, module_symbol_table: SymbolTable | None = None, no_deferral: bool = False, @@ -165,7 +161,7 @@ def analyze_member_access( is_super: bool, is_operator: bool, original_type: Type, - chk: mypy.checker.TypeChecker, + chk: TypeCheckerSharedApi, override_info: TypeInfo | None = None, in_literal_context: bool = False, self_type: Type | None = None, diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index c71d83324694..4cf7c1ca7862 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -5,8 +5,8 @@ from collections import defaultdict from typing import Final, NamedTuple -import mypy.checker from mypy import message_registry +from mypy.checker_shared import TypeCheckerSharedApi, TypeRange from mypy.checkmember import analyze_member_access from mypy.expandtype import expand_type_by_instance from mypy.join import join_types @@ -91,7 +91,7 @@ class PatternChecker(PatternVisitor[PatternType]): """ # Some services are provided by a TypeChecker instance. - chk: mypy.checker.TypeChecker + chk: TypeCheckerSharedApi # This is shared with TypeChecker, but stored also here for convenience. msg: MessageBuilder # Currently unused @@ -112,7 +112,7 @@ class PatternChecker(PatternVisitor[PatternType]): options: Options def __init__( - self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: Plugin, options: Options + self, chk: TypeCheckerSharedApi, msg: MessageBuilder, plugin: Plugin, options: Options ) -> None: self.chk = chk self.msg = msg @@ -802,7 +802,7 @@ def get_var(expr: Expression) -> Var: return node -def get_type_range(typ: Type) -> mypy.checker.TypeRange: +def get_type_range(typ: Type) -> TypeRange: typ = get_proper_type(typ) if ( isinstance(typ, Instance) @@ -810,7 +810,7 @@ def get_type_range(typ: Type) -> mypy.checker.TypeRange: and isinstance(typ.last_known_value.value, bool) ): typ = typ.last_known_value - return mypy.checker.TypeRange(typ, is_upper_bound=False) + return TypeRange(typ, is_upper_bound=False) def is_uninhabited(typ: Type) -> bool: diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 289961523b1d..45075bd37552 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -14,11 +14,15 @@ import re from re import Match, Pattern -from typing import TYPE_CHECKING, Callable, Final, Union, cast +from typing import Callable, Final, Union, cast from typing_extensions import TypeAlias as _TypeAlias import mypy.errorcodes as codes +from mypy import message_registry +from mypy.checker_shared import TypeCheckerSharedApi from mypy.errors import Errors +from mypy.maptype import map_instance_to_supertype +from mypy.messages import MessageBuilder from mypy.nodes import ( ARG_NAMED, ARG_POS, @@ -41,6 +45,9 @@ TempNode, TupleExpr, ) +from mypy.parse import parse +from mypy.subtypes import is_subtype +from mypy.typeops import custom_special_method from mypy.types import ( AnyType, Instance, @@ -57,18 +64,6 @@ get_proper_types, ) -if TYPE_CHECKING: - # break import cycle only needed for mypy - import mypy.checker - import mypy.checkexpr - -from mypy import message_registry -from mypy.maptype import map_instance_to_supertype -from mypy.messages import MessageBuilder -from mypy.parse import parse -from mypy.subtypes import is_subtype -from mypy.typeops import custom_special_method - FormatStringExpr: _TypeAlias = Union[StrExpr, BytesExpr] Checkers: _TypeAlias = tuple[Callable[[Expression], None], Callable[[Type], bool]] MatchMap: _TypeAlias = dict[tuple[int, int], Match[str]] # span -> match @@ -299,21 +294,13 @@ class StringFormatterChecker: """ # Some services are provided by a TypeChecker instance. - chk: mypy.checker.TypeChecker + chk: TypeCheckerSharedApi # This is shared with TypeChecker, but stored also here for convenience. msg: MessageBuilder - # Some services are provided by a ExpressionChecker instance. - exprchk: mypy.checkexpr.ExpressionChecker - def __init__( - self, - exprchk: mypy.checkexpr.ExpressionChecker, - chk: mypy.checker.TypeChecker, - msg: MessageBuilder, - ) -> None: + def __init__(self, chk: TypeCheckerSharedApi, msg: MessageBuilder) -> None: """Construct an expression type checker.""" self.chk = chk - self.exprchk = exprchk self.msg = msg def check_str_format_call(self, call: CallExpr, format_value: str) -> None: @@ -618,7 +605,7 @@ def apply_field_accessors( # TODO: fix column to point to actual start of the format specifier _within_ string. temp_ast.line = ctx.line temp_ast.column = ctx.column - self.exprchk.accept(temp_ast) + self.chk.expr_checker.accept(temp_ast) return temp_ast def validate_and_transform_accessors( @@ -685,7 +672,7 @@ def check_str_interpolation(self, expr: FormatStringExpr, replacements: Expressi """Check the types of the 'replacements' in a string interpolation expression: str % replacements. """ - self.exprchk.accept(expr) + self.chk.expr_checker.accept(expr) specifiers = parse_conversion_specifiers(expr.value) has_mapping_keys = self.analyze_conversion_specifiers(specifiers, expr) if has_mapping_keys is None: