Skip to content

Commit 8ca882c

Browse files
Merge branch 'master' into patch-3
2 parents f780828 + 68233f6 commit 8ca882c

15 files changed

+340
-59
lines changed

.github/workflows/mypy_primer.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ jobs:
6767
--debug \
6868
--additional-flags="--debug-serialize" \
6969
--output concise \
70+
--show-speed-regression \
7071
| tee diff_${{ matrix.shard-index }}.txt
7172
) || [ $? -eq 1 ]
7273
- if: ${{ matrix.shard-index == 0 }}

mypy/checker.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mypy import errorcodes as codes, join, message_registry, nodes, operators
1414
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
1515
from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange
16+
from mypy.checker_state import checker_state
1617
from mypy.checkmember import (
1718
MemberContext,
1819
analyze_class_attribute_access,
@@ -453,7 +454,7 @@ def check_first_pass(self) -> None:
453454
Deferred functions will be processed by check_second_pass().
454455
"""
455456
self.recurse_into_functions = True
456-
with state.strict_optional_set(self.options.strict_optional):
457+
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
457458
self.errors.set_file(
458459
self.path, self.tree.fullname, scope=self.tscope, options=self.options
459460
)
@@ -494,7 +495,7 @@ def check_second_pass(
494495
This goes through deferred nodes, returning True if there were any.
495496
"""
496497
self.recurse_into_functions = True
497-
with state.strict_optional_set(self.options.strict_optional):
498+
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
498499
if not todo and not self.deferred_nodes:
499500
return False
500501
self.errors.set_file(
@@ -6512,7 +6513,7 @@ def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expr
65126513
# and create function that will try replaying the same lookup
65136514
# operation against arbitrary types.
65146515
if isinstance(expr, MemberExpr):
6515-
parent_expr = collapse_walrus(expr.expr)
6516+
parent_expr = self._propagate_walrus_assignments(expr.expr, output)
65166517
parent_type = self.lookup_type_or_none(parent_expr)
65176518
member_name = expr.name
65186519

@@ -6535,9 +6536,10 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
65356536
return member_type
65366537

65376538
elif isinstance(expr, IndexExpr):
6538-
parent_expr = collapse_walrus(expr.base)
6539+
parent_expr = self._propagate_walrus_assignments(expr.base, output)
65396540
parent_type = self.lookup_type_or_none(parent_expr)
65406541

6542+
self._propagate_walrus_assignments(expr.index, output)
65416543
index_type = self.lookup_type_or_none(expr.index)
65426544
if index_type is None:
65436545
return output
@@ -6611,6 +6613,24 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
66116613
expr = parent_expr
66126614
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)
66136615

6616+
def _propagate_walrus_assignments(
6617+
self, expr: Expression, type_map: dict[Expression, Type]
6618+
) -> Expression:
6619+
"""Add assignments from walrus expressions to inferred types.
6620+
6621+
Only considers nested assignment exprs, does not recurse into other types.
6622+
This may be added later if necessary by implementing a dedicated visitor.
6623+
"""
6624+
if isinstance(expr, AssignmentExpr):
6625+
if isinstance(expr.value, AssignmentExpr):
6626+
self._propagate_walrus_assignments(expr.value, type_map)
6627+
assigned_type = self.lookup_type_or_none(expr.value)
6628+
parent_expr = collapse_walrus(expr)
6629+
if assigned_type is not None:
6630+
type_map[parent_expr] = assigned_type
6631+
return parent_expr
6632+
return expr
6633+
66146634
def refine_identity_comparison_expression(
66156635
self,
66166636
operands: list[Expression],

mypy/checker_state.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from contextlib import contextmanager
5+
from typing import Final
6+
7+
from mypy.checker_shared import TypeCheckerSharedApi
8+
9+
# This is global mutable state. Don't add anything here unless there's a very
10+
# good reason.
11+
12+
13+
class TypeCheckerState:
14+
# Wrap this in a class since it's faster that using a module-level attribute.
15+
16+
def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None:
17+
# Value varies by file being processed
18+
self.type_checker = type_checker
19+
20+
@contextmanager
21+
def set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
22+
saved = self.type_checker
23+
self.type_checker = value
24+
try:
25+
yield
26+
finally:
27+
self.type_checker = saved
28+
29+
30+
checker_state: Final = TypeCheckerState(type_checker=None)

mypy/checkmember.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
is_self: bool = False,
9898
rvalue: Expression | None = None,
9999
suppress_errors: bool = False,
100+
preserve_type_var_ids: bool = False,
100101
) -> None:
101102
self.is_lvalue = is_lvalue
102103
self.is_super = is_super
@@ -113,6 +114,10 @@ def __init__(
113114
assert is_lvalue
114115
self.rvalue = rvalue
115116
self.suppress_errors = suppress_errors
117+
# This attribute is only used to preserve old protocol member access logic.
118+
# It is needed to avoid infinite recursion in cases involving self-referential
119+
# generic methods, see find_member() for details. Do not use for other purposes!
120+
self.preserve_type_var_ids = preserve_type_var_ids
116121

117122
def named_type(self, name: str) -> Instance:
118123
return self.chk.named_type(name)
@@ -143,6 +148,7 @@ def copy_modified(
143148
no_deferral=self.no_deferral,
144149
rvalue=self.rvalue,
145150
suppress_errors=self.suppress_errors,
151+
preserve_type_var_ids=self.preserve_type_var_ids,
146152
)
147153
if self_type is not None:
148154
mx.self_type = self_type
@@ -232,8 +238,6 @@ def analyze_member_access(
232238
def _analyze_member_access(
233239
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
234240
) -> Type:
235-
# TODO: This and following functions share some logic with subtypes.find_member;
236-
# consider refactoring.
237241
typ = get_proper_type(typ)
238242
if isinstance(typ, Instance):
239243
return analyze_instance_member_access(name, typ, mx, override_info)
@@ -358,7 +362,8 @@ def analyze_instance_member_access(
358362
return AnyType(TypeOfAny.special_form)
359363
assert isinstance(method.type, Overloaded)
360364
signature = method.type
361-
signature = freshen_all_functions_type_vars(signature)
365+
if not mx.preserve_type_var_ids:
366+
signature = freshen_all_functions_type_vars(signature)
362367
if not method.is_static:
363368
if isinstance(method, (FuncDef, OverloadedFuncDef)) and method.is_trivial_self:
364369
signature = bind_self_fast(signature, mx.self_type)
@@ -943,7 +948,8 @@ def analyze_var(
943948
def expand_without_binding(
944949
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
945950
) -> Type:
946-
typ = freshen_all_functions_type_vars(typ)
951+
if not mx.preserve_type_var_ids:
952+
typ = freshen_all_functions_type_vars(typ)
947953
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
948954
expanded = expand_type_by_instance(typ, itype)
949955
freeze_all_type_vars(expanded)
@@ -958,7 +964,8 @@ def expand_and_bind_callable(
958964
mx: MemberContext,
959965
is_trivial_self: bool,
960966
) -> Type:
961-
functype = freshen_all_functions_type_vars(functype)
967+
if not mx.preserve_type_var_ids:
968+
functype = freshen_all_functions_type_vars(functype)
962969
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
963970
assert isinstance(typ, FunctionLike)
964971
if is_trivial_self:
@@ -1056,10 +1063,12 @@ def f(self: S) -> T: ...
10561063
return functype
10571064
else:
10581065
selfarg = get_proper_type(item.arg_types[0])
1059-
# This level of erasure matches the one in checker.check_func_def(),
1060-
# better keep these two checks consistent.
1061-
if subtypes.is_subtype(
1066+
# This matches similar special-casing in bind_self(), see more details there.
1067+
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
1068+
if self_callable or subtypes.is_subtype(
10621069
dispatched_arg_type,
1070+
# This level of erasure matches the one in checker.check_func_def(),
1071+
# better keep these two checks consistent.
10631072
erase_typevars(erase_to_bound(selfarg)),
10641073
# This is to work around the fact that erased ParamSpec and TypeVarTuple
10651074
# callables are not always compatible with non-erased ones both ways.
@@ -1220,9 +1229,6 @@ def analyze_class_attribute_access(
12201229
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
12211230
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
12221231
)
1223-
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
1224-
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
1225-
)
12261232
t = get_proper_type(t)
12271233
is_trivial_self = False
12281234
if isinstance(node.node, Decorator):
@@ -1236,8 +1242,7 @@ def analyze_class_attribute_access(
12361242
t,
12371243
isuper,
12381244
is_classmethod,
1239-
is_staticmethod,
1240-
mx.self_type,
1245+
mx,
12411246
original_vars=original_vars,
12421247
is_trivial_self=is_trivial_self,
12431248
)
@@ -1372,8 +1377,7 @@ def add_class_tvars(
13721377
t: ProperType,
13731378
isuper: Instance | None,
13741379
is_classmethod: bool,
1375-
is_staticmethod: bool,
1376-
original_type: Type,
1380+
mx: MemberContext,
13771381
original_vars: Sequence[TypeVarLikeType] | None = None,
13781382
is_trivial_self: bool = False,
13791383
) -> Type:
@@ -1392,9 +1396,6 @@ class B(A[str]): pass
13921396
isuper: Current instance mapped to the superclass where method was defined, this
13931397
is usually done by map_instance_to_supertype()
13941398
is_classmethod: True if this method is decorated with @classmethod
1395-
is_staticmethod: True if this method is decorated with @staticmethod
1396-
original_type: The value of the type B in the expression B.foo() or the corresponding
1397-
component in case of a union (this is used to bind the self-types)
13981399
original_vars: Type variables of the class callable on which the method was accessed
13991400
is_trivial_self: if True, we can use fast path for bind_self().
14001401
Returns:
@@ -1416,14 +1417,14 @@ class B(A[str]): pass
14161417
# (i.e. appear in the return type of the class object on which the method was accessed).
14171418
if isinstance(t, CallableType):
14181419
tvars = original_vars if original_vars is not None else []
1419-
t = freshen_all_functions_type_vars(t)
1420+
if not mx.preserve_type_var_ids:
1421+
t = freshen_all_functions_type_vars(t)
14201422
if is_classmethod:
14211423
if is_trivial_self:
1422-
t = bind_self_fast(t, original_type)
1424+
t = bind_self_fast(t, mx.self_type)
14231425
else:
1424-
t = bind_self(t, original_type, is_classmethod=True)
1425-
if is_classmethod or is_staticmethod:
1426-
assert isuper is not None
1426+
t = bind_self(t, mx.self_type, is_classmethod=True)
1427+
if isuper is not None:
14271428
t = expand_type_by_instance(t, isuper)
14281429
freeze_all_type_vars(t)
14291430
return t.copy_modified(variables=list(tvars) + list(t.variables))
@@ -1432,14 +1433,7 @@ class B(A[str]): pass
14321433
[
14331434
cast(
14341435
CallableType,
1435-
add_class_tvars(
1436-
item,
1437-
isuper,
1438-
is_classmethod,
1439-
is_staticmethod,
1440-
original_type,
1441-
original_vars=original_vars,
1442-
),
1436+
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
14431437
)
14441438
for item in t.items
14451439
]

mypy/messages.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,8 +2220,13 @@ def report_protocol_problems(
22202220
exp = get_proper_type(exp)
22212221
got = get_proper_type(got)
22222222
setter_suffix = " setter type" if is_lvalue else ""
2223-
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
2224-
got, (CallableType, Overloaded)
2223+
if (
2224+
not isinstance(exp, (CallableType, Overloaded))
2225+
or not isinstance(got, (CallableType, Overloaded))
2226+
# If expected type is a type object, it means it is a nested class.
2227+
# Showing constructor signature in errors would be confusing in this case,
2228+
# since we don't check the signature, only subclassing of type objects.
2229+
or exp.is_type_obj()
22252230
):
22262231
self.note(
22272232
"{}: expected{} {}, got {}".format(

mypy/plugin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,13 @@ class C: pass
119119
from __future__ import annotations
120120

121121
from abc import abstractmethod
122-
from typing import Any, Callable, NamedTuple, TypeVar
122+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar
123123

124124
from mypy_extensions import mypyc_attr, trait
125125

126126
from mypy.errorcodes import ErrorCode
127127
from mypy.lookup import lookup_fully_qualified
128128
from mypy.message_registry import ErrorMessage
129-
from mypy.messages import MessageBuilder
130129
from mypy.nodes import (
131130
ArgKind,
132131
CallExpr,
@@ -138,7 +137,6 @@ class C: pass
138137
TypeInfo,
139138
)
140139
from mypy.options import Options
141-
from mypy.tvar_scope import TypeVarLikeScope
142140
from mypy.types import (
143141
CallableType,
144142
FunctionLike,
@@ -149,6 +147,10 @@ class C: pass
149147
UnboundType,
150148
)
151149

150+
if TYPE_CHECKING:
151+
from mypy.messages import MessageBuilder
152+
from mypy.tvar_scope import TypeVarLikeScope
153+
152154

153155
@trait
154156
class TypeAnalyzerPluginInterface:

0 commit comments

Comments
 (0)