Skip to content

Commit 3a3cdf8

Browse files
committed
Use checkmember.py to check protocol subtyping
1 parent a4e79ea commit 3a3cdf8

File tree

10 files changed

+139
-51
lines changed

10 files changed

+139
-51
lines changed

mypy/checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def check_first_pass(self) -> None:
455455
Deferred functions will be processed by check_second_pass().
456456
"""
457457
self.recurse_into_functions = True
458-
with state.strict_optional_set(self.options.strict_optional):
458+
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
459459
self.errors.set_file(
460460
self.path, self.tree.fullname, scope=self.tscope, options=self.options
461461
)
@@ -496,7 +496,7 @@ def check_second_pass(
496496
This goes through deferred nodes, returning True if there were any.
497497
"""
498498
self.recurse_into_functions = True
499-
with state.strict_optional_set(self.options.strict_optional):
499+
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
500500
if not todo and not self.deferred_nodes:
501501
return False
502502
self.errors.set_file(

mypy/checkmember.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
is_self: bool = False,
9797
rvalue: Expression | None = None,
9898
suppress_errors: bool = False,
99+
preserve_type_var_ids: bool = False,
99100
) -> None:
100101
self.is_lvalue = is_lvalue
101102
self.is_super = is_super
@@ -112,6 +113,10 @@ def __init__(
112113
assert is_lvalue
113114
self.rvalue = rvalue
114115
self.suppress_errors = suppress_errors
116+
# This attribute is only used to preserve old protocol member access logic.
117+
# It is needed to avoid infinite recursion in cases involving self-referential
118+
# generic methods, see find_member() for details. Do not use for other purposes!
119+
self.preserve_type_var_ids = preserve_type_var_ids
115120

116121
def named_type(self, name: str) -> Instance:
117122
return self.chk.named_type(name)
@@ -142,6 +147,7 @@ def copy_modified(
142147
no_deferral=self.no_deferral,
143148
rvalue=self.rvalue,
144149
suppress_errors=self.suppress_errors,
150+
preserve_type_var_ids=self.preserve_type_var_ids,
145151
)
146152
if self_type is not None:
147153
mx.self_type = self_type
@@ -231,8 +237,6 @@ def analyze_member_access(
231237
def _analyze_member_access(
232238
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
233239
) -> Type:
234-
# TODO: This and following functions share some logic with subtypes.find_member;
235-
# consider refactoring.
236240
typ = get_proper_type(typ)
237241
if isinstance(typ, Instance):
238242
return analyze_instance_member_access(name, typ, mx, override_info)
@@ -355,7 +359,8 @@ def analyze_instance_member_access(
355359
return AnyType(TypeOfAny.special_form)
356360
assert isinstance(method.type, Overloaded)
357361
signature = method.type
358-
signature = freshen_all_functions_type_vars(signature)
362+
if not mx.preserve_type_var_ids:
363+
signature = freshen_all_functions_type_vars(signature)
359364
if not method.is_static:
360365
signature = check_self_arg(
361366
signature, mx.self_type, method.is_class, mx.context, name, mx.msg
@@ -928,7 +933,8 @@ def analyze_var(
928933
def expand_without_binding(
929934
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
930935
) -> Type:
931-
typ = freshen_all_functions_type_vars(typ)
936+
if not mx.preserve_type_var_ids:
937+
typ = freshen_all_functions_type_vars(typ)
932938
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
933939
expanded = expand_type_by_instance(typ, itype)
934940
freeze_all_type_vars(expanded)
@@ -938,7 +944,8 @@ def expand_without_binding(
938944
def expand_and_bind_callable(
939945
functype: FunctionLike, var: Var, itype: Instance, name: str, mx: MemberContext
940946
) -> Type:
941-
functype = freshen_all_functions_type_vars(functype)
947+
if not mx.preserve_type_var_ids:
948+
functype = freshen_all_functions_type_vars(functype)
942949
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
943950
assert isinstance(typ, FunctionLike)
944951
typ = check_self_arg(typ, mx.self_type, var.is_classmethod, mx.context, name, mx.msg)
@@ -1033,10 +1040,12 @@ def f(self: S) -> T: ...
10331040
return functype
10341041
else:
10351042
selfarg = get_proper_type(item.arg_types[0])
1036-
# This level of erasure matches the one in checker.check_func_def(),
1037-
# better keep these two checks consistent.
1038-
if subtypes.is_subtype(
1043+
# This matches similar special-casing in bind_self(), see more details there.
1044+
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
1045+
if self_callable or subtypes.is_subtype(
10391046
dispatched_arg_type,
1047+
# This level of erasure matches the one in checker.check_func_def(),
1048+
# better keep these two checks consistent.
10401049
erase_typevars(erase_to_bound(selfarg)),
10411050
# This is to work around the fact that erased ParamSpec and TypeVarTuple
10421051
# callables are not always compatible with non-erased ones both ways.
@@ -1197,15 +1206,10 @@ def analyze_class_attribute_access(
11971206
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
11981207
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
11991208
)
1200-
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
1201-
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
1202-
)
12031209
t = get_proper_type(t)
12041210
if isinstance(t, FunctionLike) and is_classmethod:
12051211
t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg)
1206-
result = add_class_tvars(
1207-
t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars
1208-
)
1212+
result = add_class_tvars(t, isuper, is_classmethod, mx, original_vars=original_vars)
12091213
# __set__ is not called on class objects.
12101214
if not mx.is_lvalue:
12111215
result = analyze_descriptor_access(result, mx)
@@ -1337,8 +1341,7 @@ def add_class_tvars(
13371341
t: ProperType,
13381342
isuper: Instance | None,
13391343
is_classmethod: bool,
1340-
is_staticmethod: bool,
1341-
original_type: Type,
1344+
mx: MemberContext,
13421345
original_vars: Sequence[TypeVarLikeType] | None = None,
13431346
) -> Type:
13441347
"""Instantiate type variables during analyze_class_attribute_access,
@@ -1356,9 +1359,6 @@ class B(A[str]): pass
13561359
isuper: Current instance mapped to the superclass where method was defined, this
13571360
is usually done by map_instance_to_supertype()
13581361
is_classmethod: True if this method is decorated with @classmethod
1359-
is_staticmethod: True if this method is decorated with @staticmethod
1360-
original_type: The value of the type B in the expression B.foo() or the corresponding
1361-
component in case of a union (this is used to bind the self-types)
13621362
original_vars: Type variables of the class callable on which the method was accessed
13631363
Returns:
13641364
Expanded method type with added type variables (when needed).
@@ -1379,11 +1379,11 @@ class B(A[str]): pass
13791379
# (i.e. appear in the return type of the class object on which the method was accessed).
13801380
if isinstance(t, CallableType):
13811381
tvars = original_vars if original_vars is not None else []
1382-
t = freshen_all_functions_type_vars(t)
1382+
if not mx.preserve_type_var_ids:
1383+
t = freshen_all_functions_type_vars(t)
13831384
if is_classmethod:
1384-
t = bind_self(t, original_type, is_classmethod=True)
1385-
if is_classmethod or is_staticmethod:
1386-
assert isuper is not None
1385+
t = bind_self(t, mx.self_type, is_classmethod=True)
1386+
if isuper is not None:
13871387
t = expand_type_by_instance(t, isuper)
13881388
freeze_all_type_vars(t)
13891389
return t.copy_modified(variables=list(tvars) + list(t.variables))
@@ -1392,14 +1392,7 @@ class B(A[str]): pass
13921392
[
13931393
cast(
13941394
CallableType,
1395-
add_class_tvars(
1396-
item,
1397-
isuper,
1398-
is_classmethod,
1399-
is_staticmethod,
1400-
original_type,
1401-
original_vars=original_vars,
1402-
),
1395+
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
14031396
)
14041397
for item in t.items
14051398
]

mypy/expandtype.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Final, TypeVar, cast, overload
55

66
from mypy.nodes import ARG_STAR, FakeInfo, Var
7-
from mypy.state import state
87
from mypy.types import (
98
ANY_STRATEGY,
109
AnyType,
@@ -544,6 +543,8 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
544543
* Remove everything else if there is an `object`
545544
* Remove strict duplicate types
546545
"""
546+
from mypy.state import state
547+
547548
removed_none = False
548549
new_types = []
549550
all_types = set()

mypy/messages.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,8 +2220,13 @@ def report_protocol_problems(
22202220
exp = get_proper_type(exp)
22212221
got = get_proper_type(got)
22222222
setter_suffix = " setter type" if is_lvalue else ""
2223-
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
2224-
got, (CallableType, Overloaded)
2223+
if (
2224+
not isinstance(exp, (CallableType, Overloaded))
2225+
or not isinstance(got, (CallableType, Overloaded))
2226+
# If expected type is a type object, it means it is a nested class.
2227+
# Showing constructor signature in errors would be confusing in this case,
2228+
# since we don't check the signature, only subclassing of type objects.
2229+
or exp.is_type_obj()
22252230
):
22262231
self.note(
22272232
"{}: expected{} {}, got {}".format(

mypy/plugin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,13 @@ class C: pass
119119
from __future__ import annotations
120120

121121
from abc import abstractmethod
122-
from typing import Any, Callable, NamedTuple, TypeVar
122+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar
123123

124124
from mypy_extensions import mypyc_attr, trait
125125

126126
from mypy.errorcodes import ErrorCode
127127
from mypy.lookup import lookup_fully_qualified
128128
from mypy.message_registry import ErrorMessage
129-
from mypy.messages import MessageBuilder
130129
from mypy.nodes import (
131130
ArgKind,
132131
CallExpr,
@@ -138,7 +137,6 @@ class C: pass
138137
TypeInfo,
139138
)
140139
from mypy.options import Options
141-
from mypy.tvar_scope import TypeVarLikeScope
142140
from mypy.types import (
143141
CallableType,
144142
FunctionLike,
@@ -149,6 +147,10 @@ class C: pass
149147
UnboundType,
150148
)
151149

150+
if TYPE_CHECKING:
151+
from mypy.messages import MessageBuilder
152+
from mypy.tvar_scope import TypeVarLikeScope
153+
152154

153155
@trait
154156
class TypeAnalyzerPluginInterface:

mypy/state.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
from contextlib import contextmanager
55
from typing import Final
66

7+
from mypy.checker_shared import TypeCheckerSharedApi
8+
79
# These are global mutable state. Don't add anything here unless there's a very
810
# good reason.
911

1012

11-
class StrictOptionalState:
13+
class SubtypeState:
1214
# Wrap this in a class since it's faster that using a module-level attribute.
1315

14-
def __init__(self, strict_optional: bool) -> None:
15-
# Value varies by file being processed
16+
def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None:
17+
# Values vary by file being processed
1618
self.strict_optional = strict_optional
19+
self.type_checker = type_checker
1720

1821
@contextmanager
1922
def strict_optional_set(self, value: bool) -> Iterator[None]:
@@ -24,6 +27,15 @@ def strict_optional_set(self, value: bool) -> Iterator[None]:
2427
finally:
2528
self.strict_optional = saved
2629

30+
@contextmanager
31+
def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
32+
saved = self.type_checker
33+
self.type_checker = value
34+
try:
35+
yield
36+
finally:
37+
self.type_checker = saved
38+
2739

28-
state: Final = StrictOptionalState(strict_optional=True)
40+
state: Final = SubtypeState(strict_optional=True, type_checker=None)
2941
find_occurrences: tuple[str, str] | None = None

mypy/subtypes.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
COVARIANT,
2727
INVARIANT,
2828
VARIANCE_NOT_READY,
29+
Context,
2930
Decorator,
3031
FuncBase,
3132
OverloadedFuncDef,
@@ -717,8 +718,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
717718
elif isinstance(right, Instance):
718719
if right.type.is_protocol and "__call__" in right.type.protocol_members:
719720
# OK, a callable can implement a protocol with a `__call__` member.
720-
# TODO: we should probably explicitly exclude self-types in this case.
721-
call = find_member("__call__", right, left, is_operator=True)
721+
call = find_member("__call__", right, right, is_operator=True)
722722
assert call is not None
723723
if self._is_subtype(left, call):
724724
if len(right.type.protocol_members) == 1:
@@ -954,7 +954,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
954954
if isinstance(right, Instance):
955955
if right.type.is_protocol and "__call__" in right.type.protocol_members:
956956
# same as for CallableType
957-
call = find_member("__call__", right, left, is_operator=True)
957+
call = find_member("__call__", right, right, is_operator=True)
958958
assert call is not None
959959
if self._is_subtype(left, call):
960960
if len(right.type.protocol_members) == 1:
@@ -1261,14 +1261,87 @@ def find_member(
12611261
is_operator: bool = False,
12621262
class_obj: bool = False,
12631263
is_lvalue: bool = False,
1264+
) -> Type | None:
1265+
type_checker = state.type_checker
1266+
if type_checker is None:
1267+
# Unfortunately, there are many scenarios where someone calls is_subtype() before
1268+
# type checking phase. In this case we fallback to old (incomplete) logic.
1269+
# TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins).
1270+
return find_member_simple(
1271+
name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue
1272+
)
1273+
1274+
# We don't use ATTR_DEFINED error code below (since missing attributes can cause various
1275+
# other error codes), instead we perform quick node lookup with all the fallbacks.
1276+
info = itype.type
1277+
sym = info.get(name)
1278+
node = sym.node if sym else None
1279+
if not node:
1280+
name_not_found = True
1281+
if (
1282+
name not in ["__getattr__", "__setattr__", "__getattribute__"]
1283+
and not is_operator
1284+
and not class_obj
1285+
and itype.extra_attrs is None # skip ModuleType.__getattr__
1286+
):
1287+
for method_name in ("__getattribute__", "__getattr__"):
1288+
method = info.get_method(method_name)
1289+
if method and method.info.fullname != "builtins.object":
1290+
name_not_found = False
1291+
break
1292+
if name_not_found:
1293+
if info.fallback_to_any or class_obj and info.meta_fallback_to_any:
1294+
return AnyType(TypeOfAny.special_form)
1295+
if itype.extra_attrs and name in itype.extra_attrs.attrs:
1296+
return itype.extra_attrs.attrs[name]
1297+
return None
1298+
1299+
from mypy.checkmember import (
1300+
MemberContext,
1301+
analyze_class_attribute_access,
1302+
analyze_instance_member_access,
1303+
)
1304+
1305+
mx = MemberContext(
1306+
is_lvalue=is_lvalue,
1307+
is_super=False,
1308+
is_operator=is_operator,
1309+
original_type=itype,
1310+
self_type=subtype,
1311+
context=Context(), # all errors are filtered, but this is a required argument
1312+
chk=type_checker,
1313+
suppress_errors=True,
1314+
# This is needed to avoid infinite recursion in situations involving protocols like
1315+
# class P(Protocol[T]):
1316+
# def combine(self, other: P[S]) -> P[Tuple[T, S]]: ...
1317+
# Normally we call freshen_all_functions_type_vars() during attribute access,
1318+
# to avoid type variable id collisions, but for protocols this means we can't
1319+
# use the assumption stack, that will grow indefinitely.
1320+
# TODO: find a cleaner solution that doesn't involve massive perf impact.
1321+
preserve_type_var_ids=True,
1322+
)
1323+
with type_checker.msg.filter_errors(filter_deprecated=True):
1324+
if class_obj:
1325+
fallback = itype.type.metaclass_type or mx.named_type("builtins.type")
1326+
return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback)
1327+
else:
1328+
return analyze_instance_member_access(name, itype, mx, info)
1329+
1330+
1331+
def find_member_simple(
1332+
name: str,
1333+
itype: Instance,
1334+
subtype: Type,
1335+
*,
1336+
is_operator: bool = False,
1337+
class_obj: bool = False,
1338+
is_lvalue: bool = False,
12641339
) -> Type | None:
12651340
"""Find the type of member by 'name' in 'itype's TypeInfo.
12661341
12671342
Find the member type after applying type arguments from 'itype', and binding
12681343
'self' to 'subtype'. Return None if member was not found.
12691344
"""
1270-
# TODO: this code shares some logic with checkmember.analyze_member_access,
1271-
# consider refactoring.
12721345
info = itype.type
12731346
method = info.get_method(name)
12741347
if method:

0 commit comments

Comments
 (0)