Skip to content

Commit 126abba

Browse files
committed
Try more principled approach
1 parent 78883da commit 126abba

File tree

5 files changed

+128
-129
lines changed

5 files changed

+128
-129
lines changed

mypy/checker.py

Lines changed: 88 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
1515
from mypy.checkmember import (
1616
MemberContext,
17-
analyze_descriptor_access,
17+
analyze_class_attribute_access,
1818
analyze_instance_member_access,
1919
analyze_member_access,
2020
)
@@ -3261,16 +3261,6 @@ def check_assignment(
32613261
if active_class and dataclasses_plugin.is_processed_dataclass(active_class):
32623262
self.fail(message_registry.DATACLASS_POST_INIT_MUST_BE_A_FUNCTION, rvalue)
32633263

3264-
# Defer PartialType's super type checking.
3265-
if (
3266-
isinstance(lvalue, RefExpr)
3267-
and not (isinstance(lvalue_type, PartialType) and lvalue_type.type is None)
3268-
and not (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
3269-
):
3270-
if self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue):
3271-
# We hit an error on this line; don't check for any others
3272-
return
3273-
32743264
if isinstance(lvalue, MemberExpr) and lvalue.name == "__match_args__":
32753265
self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue)
32763266

@@ -3302,12 +3292,6 @@ def check_assignment(
33023292
# Try to infer a partial type. No need to check the return value, as
33033293
# an error will be reported elsewhere.
33043294
self.infer_partial_type(lvalue_type.var, lvalue, rvalue_type)
3305-
# Handle None PartialType's super type checking here, after it's resolved.
3306-
if isinstance(lvalue, RefExpr) and self.check_compatibility_all_supers(
3307-
lvalue, lvalue_type, rvalue
3308-
):
3309-
# We hit an error on this line; don't check for any others
3310-
return
33113295
elif (
33123296
is_literal_none(rvalue)
33133297
and isinstance(lvalue, NameExpr)
@@ -3399,7 +3383,7 @@ def check_assignment(
33993383
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
34003384

34013385
if inferred:
3402-
type_context = self.get_variable_type_context(inferred)
3386+
type_context = self.get_variable_type_context(inferred, rvalue)
34033387
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
34043388
if not (
34053389
inferred.is_final
@@ -3409,15 +3393,26 @@ def check_assignment(
34093393
rvalue_type = remove_instance_last_known_values(rvalue_type)
34103394
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
34113395
self.check_assignment_to_slots(lvalue)
3396+
if isinstance(lvalue, RefExpr) and not (
3397+
isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__"
3398+
):
3399+
# We check override here at the end after storing the inferred type, since
3400+
# override check will try to access the current attribute via symbol tables
3401+
# (like a regular attribute access).
3402+
self.check_compatibility_all_supers(lvalue, rvalue)
34123403

34133404
# (type, operator) tuples for augmented assignments supported with partial types
34143405
partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")}
34153406

3416-
def get_variable_type_context(self, inferred: Var) -> Type | None:
3407+
def get_variable_type_context(self, inferred: Var, rvalue: Expression) -> Type | None:
34173408
type_contexts = []
34183409
if inferred.info:
34193410
for base in inferred.info.mro[1:]:
3420-
base_type, base_node = self.lvalue_type_from_base(inferred, base)
3411+
# For inference within class body, get supertype attribute as it would look on
3412+
# a class object for lambdas overriding methods.
3413+
base_type, base_node = self.lvalue_type_from_base(
3414+
inferred, base, is_class=isinstance(rvalue, LambdaExpr)
3415+
)
34213416
if (
34223417
base_type
34233418
and not (isinstance(base_node, Var) and base_node.invalid_partial_type)
@@ -3484,15 +3479,21 @@ def try_infer_partial_generic_type_from_assignment(
34843479
var.type = fill_typevars_with_any(typ.type)
34853480
del partial_types[var]
34863481

3487-
def check_compatibility_all_supers(
3488-
self, lvalue: RefExpr, lvalue_type: Type | None, rvalue: Expression
3489-
) -> bool:
3482+
def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) -> None:
34903483
lvalue_node = lvalue.node
34913484
# Check if we are a class variable with at least one base class
34923485
if (
34933486
isinstance(lvalue_node, Var)
3494-
and lvalue.kind in (MDEF, None)
3495-
and len(lvalue_node.info.bases) > 0 # None for Vars defined via self
3487+
# If we have explicit annotation, there is no point in checking the override
3488+
# for each assignment, so we check only for the first one.
3489+
# TODO: for some reason annotated attributes on self are stored as inferred vars.
3490+
and (
3491+
lvalue_node.line == lvalue.line
3492+
or lvalue_node.is_inferred
3493+
and not lvalue_node.explicit_self_type
3494+
)
3495+
and lvalue.kind in (MDEF, None) # None for Vars defined via self
3496+
and len(lvalue_node.info.bases) > 0
34963497
):
34973498
for base in lvalue_node.info.mro[1:]:
34983499
tnode = base.names.get(lvalue_node.name)
@@ -3508,6 +3509,21 @@ def check_compatibility_all_supers(
35083509
direct_bases = lvalue_node.info.direct_base_classes()
35093510
last_immediate_base = direct_bases[-1] if direct_bases else None
35103511

3512+
# The historical behavior for inferred vars was to compare rvalue type against
3513+
# the type declared in a superclass. To preserve this behavior, we temporarily
3514+
# store the rvalue type on the variable.
3515+
actual_lvalue_type = None
3516+
if lvalue_node.is_inferred and not lvalue_node.explicit_self_type:
3517+
rvalue_type = self.expr_checker.accept(rvalue, lvalue_node.type)
3518+
actual_lvalue_type = lvalue_node.type
3519+
lvalue_node.type = rvalue_type
3520+
lvalue_type, _ = self.lvalue_type_from_base(lvalue_node, lvalue_node.info)
3521+
if lvalue_node.is_inferred and not lvalue_node.explicit_self_type:
3522+
lvalue_node.type = actual_lvalue_type
3523+
3524+
if not lvalue_type:
3525+
return
3526+
35113527
for base in lvalue_node.info.mro[1:]:
35123528
# The type of "__slots__" and some other attributes usually doesn't need to
35133529
# be compatible with a base class. We'll still check the type of "__slots__"
@@ -3528,7 +3544,6 @@ def check_compatibility_all_supers(
35283544
if base_type:
35293545
assert base_node is not None
35303546
if not self.check_compatibility_super(
3531-
lvalue,
35323547
lvalue_type,
35333548
rvalue,
35343549
base,
@@ -3538,7 +3553,7 @@ def check_compatibility_all_supers(
35383553
):
35393554
# Only show one error per variable; even if other
35403555
# base classes are also incompatible
3541-
return True
3556+
return
35423557
if lvalue_type and custom_setter:
35433558
base_type, _ = self.lvalue_type_from_base(
35443559
lvalue_node, base, setter_type=True
@@ -3550,104 +3565,49 @@ def check_compatibility_all_supers(
35503565
self.msg.incompatible_setter_override(
35513566
lvalue, lvalue_type, base_type, base
35523567
)
3553-
return True
3568+
return
35543569
if base is last_immediate_base:
35553570
# At this point, the attribute was found to be compatible with all
35563571
# immediate parents.
35573572
break
3558-
return False
35593573

35603574
def check_compatibility_super(
35613575
self,
3562-
lvalue: RefExpr,
3563-
lvalue_type: Type | None,
3576+
compare_type: Type,
35643577
rvalue: Expression,
35653578
base: TypeInfo,
35663579
base_type: Type,
35673580
base_node: Node,
35683581
always_allow_covariant: bool,
35693582
) -> bool:
3570-
lvalue_node = lvalue.node
3571-
assert isinstance(lvalue_node, Var)
3572-
3573-
# Do not check whether the rvalue is compatible if the
3574-
# lvalue had a type defined; this is handled by other
3575-
# parts, and all we have to worry about in that case is
3576-
# that lvalue is compatible with the base class.
3577-
compare_node = None
3578-
if lvalue_type:
3579-
compare_type = lvalue_type
3580-
compare_node = lvalue.node
3581-
else:
3582-
compare_type = self.expr_checker.accept(rvalue, base_type)
3583-
if isinstance(rvalue, NameExpr):
3584-
compare_node = rvalue.node
3585-
if isinstance(compare_node, Decorator):
3586-
compare_node = compare_node.func
3587-
3588-
base_type = get_proper_type(base_type)
3589-
compare_type = get_proper_type(compare_type)
3590-
if compare_type:
3591-
# Unlike for base_type, where we use the full analyze_member_access(), we know
3592-
# that subclass node is an assignment, so we use a much simpler logic: just bind
3593-
# self or invoke descriptors. Although this may not cover some niche corner
3594-
# cases, it is faster, and works with current logic where we check overrides
3595-
# before storing inferred variable type.
3596-
if isinstance(compare_type, CallableType):
3597-
compare_static = is_node_static(compare_node)
3598-
# Since this method may be called before storing inferred type,
3599-
# fall back to rvalue to obtain the static method flag.
3600-
if compare_static is None and compare_type.definition:
3601-
compare_static = is_node_static(compare_type.definition)
3602-
# Compare against False, as is_node_static can return None
3603-
if compare_static is False:
3604-
# TODO: handle aliases to class methods (similarly).
3605-
compare_type = bind_self(compare_type, self.scope.active_self_type())
3606-
3607-
elif isinstance(compare_type, Instance):
3608-
self_type = self.scope.active_self_type()
3609-
assert self_type is not None, "Internal error: base lookup outside class"
3610-
mx = MemberContext(
3611-
is_lvalue=False,
3612-
is_super=False,
3613-
is_operator=False,
3614-
original_type=self_type,
3615-
context=lvalue,
3616-
chk=self,
3617-
suppress_errors=True,
3618-
)
3619-
with self.msg.filter_errors():
3620-
compare_type = analyze_descriptor_access(compare_type, mx)
3621-
3622-
# TODO: check __set__() type override for custom descriptors.
3623-
# TODO: for descriptors check also class object access override.
3583+
# TODO: check __set__() type override for custom descriptors.
3584+
# TODO: for descriptors check also class object access override.
3585+
ok = self.check_subtype(
3586+
compare_type,
3587+
base_type,
3588+
rvalue,
3589+
message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
3590+
"expression has type",
3591+
f'base class "{base.name}" defined the type as',
3592+
)
3593+
if (
3594+
ok
3595+
and codes.MUTABLE_OVERRIDE in self.options.enabled_error_codes
3596+
and self.is_writable_attribute(base_node)
3597+
and not always_allow_covariant
3598+
):
36243599
ok = self.check_subtype(
3625-
compare_type,
36263600
base_type,
3601+
compare_type,
36273602
rvalue,
3628-
message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
3629-
"expression has type",
3603+
message_registry.COVARIANT_OVERRIDE_OF_MUTABLE_ATTRIBUTE,
36303604
f'base class "{base.name}" defined the type as',
3605+
"expression has type",
36313606
)
3632-
if (
3633-
ok
3634-
and codes.MUTABLE_OVERRIDE in self.options.enabled_error_codes
3635-
and self.is_writable_attribute(base_node)
3636-
and not always_allow_covariant
3637-
):
3638-
ok = self.check_subtype(
3639-
base_type,
3640-
compare_type,
3641-
rvalue,
3642-
message_registry.COVARIANT_OVERRIDE_OF_MUTABLE_ATTRIBUTE,
3643-
f'base class "{base.name}" defined the type as',
3644-
"expression has type",
3645-
)
3646-
return ok
3647-
return True
3607+
return ok
36483608

36493609
def lvalue_type_from_base(
3650-
self, expr_node: Var, base: TypeInfo, setter_type: bool = False
3610+
self, expr_node: Var, base: TypeInfo, setter_type: bool = False, is_class: bool = False
36513611
) -> tuple[Type | None, SymbolNode | None]:
36523612
"""Find a type for a variable name in base class.
36533613
@@ -3661,10 +3621,15 @@ def lvalue_type_from_base(
36613621
base_var = base.names.get(expr_name)
36623622

36633623
# TODO: defer current node if the superclass node is not ready.
3664-
if not base_var or not base_var.type:
3624+
if (
3625+
not base_var
3626+
or not base_var.type
3627+
or isinstance(base_var.type, PartialType)
3628+
and base_var.type.type is not None
3629+
):
36653630
return None, None
36663631

3667-
self_type = self.scope.active_self_type()
3632+
self_type = self.scope.current_self_type()
36683633
assert self_type is not None, "Internal error: base lookup outside class"
36693634
if isinstance(self_type, TupleType):
36703635
instance = tuple_fallback(self_type)
@@ -3681,8 +3646,14 @@ def lvalue_type_from_base(
36813646
suppress_errors=True,
36823647
)
36833648
# TODO: we should not filter "cannot determine type" errors here.
3684-
with self.msg.filter_errors():
3685-
base_type = analyze_instance_member_access(expr_name, instance, mx, base)
3649+
with self.msg.filter_errors(filter_deprecated=True):
3650+
if is_class:
3651+
fallback = instance.type.metaclass_type or mx.named_type("builtins.type")
3652+
base_type = analyze_class_attribute_access(
3653+
instance, expr_name, mx, mcs_fallback=fallback, override_info=base
3654+
)
3655+
else:
3656+
base_type = analyze_instance_member_access(expr_name, instance, mx, base)
36863657
return base_type, base_var.node
36873658

36883659
def check_compatibility_classvar_super(
@@ -4522,6 +4493,7 @@ def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
45224493
self.store_type(lvalue, type)
45234494
p_type = get_proper_type(type)
45244495
if isinstance(p_type, CallableType) and is_node_static(p_type.definition):
4496+
# TODO: handle aliases to class methods (similarly).
45254497
var.is_staticmethod = True
45264498

45274499
def set_inference_error_fallback_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
@@ -8685,6 +8657,13 @@ def active_self_type(self) -> Instance | TupleType | None:
86858657
return fill_typevars(info)
86868658
return None
86878659

8660+
def current_self_type(self) -> Instance | TupleType | None:
8661+
"""Same as active_self_type() but handle functions nested in methods."""
8662+
for item in reversed(self.stack):
8663+
if isinstance(item, TypeInfo):
8664+
return fill_typevars(item)
8665+
return None
8666+
86888667
@contextmanager
86898668
def push_function(self, item: FuncItem) -> Iterator[None]:
86908669
self.stack.append(item)

0 commit comments

Comments
 (0)