Skip to content

Commit e067603

Browse files
committed
Use bases with substituted generic args when checking multiple inheritance compatibility
1 parent ec4ccb0 commit e067603

File tree

5 files changed

+151
-56
lines changed

5 files changed

+151
-56
lines changed

mypy/checker.py

Lines changed: 106 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,7 +2104,9 @@ def check_method_override_for_base_with_name(
21042104
original_class_or_static = False # a variable can't be class or static
21052105

21062106
if isinstance(original_type, FunctionLike):
2107-
original_type = self.bind_and_map_method(base_attr, original_type, defn.info, base)
2107+
original_type = self.bind_and_map_method(
2108+
base_attr.node, original_type, defn.info, base
2109+
)
21082110
if original_node and is_property(original_node):
21092111
original_type = get_property_type(original_type)
21102112

@@ -2200,7 +2202,7 @@ def check_method_override_for_base_with_name(
22002202
return False
22012203

22022204
def bind_and_map_method(
2203-
self, sym: SymbolTableNode, typ: FunctionLike, sub_info: TypeInfo, super_info: TypeInfo
2205+
self, node: Node | None, typ: FunctionLike, sub_info: TypeInfo, super_info: TypeInfo
22042206
) -> FunctionLike:
22052207
"""Bind self-type and map type variables for a method.
22062208
@@ -2210,13 +2212,11 @@ def bind_and_map_method(
22102212
sub_info: class where the method is used
22112213
super_info: class where the method was defined
22122214
"""
2213-
if isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static(
2214-
sym.node
2215-
):
2216-
if isinstance(sym.node, Decorator):
2217-
is_class_method = sym.node.func.is_class
2215+
if isinstance(node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static(node):
2216+
if isinstance(node, Decorator):
2217+
is_class_method = node.func.is_class
22182218
else:
2219-
is_class_method = sym.node.is_class
2219+
is_class_method = node.is_class
22202220

22212221
mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
22222222
active_self_type = self.scope.active_self_type()
@@ -2745,46 +2745,45 @@ def check_multiple_inheritance(self, typ: TypeInfo) -> None:
27452745
# No multiple inheritance.
27462746
return
27472747
# Verify that inherited attributes are compatible.
2748-
mro = typ.mro[1:]
2749-
for i, base in enumerate(mro):
2748+
bases = typ.bases
2749+
all_names = [{n for p in b.type.mro for n in p.names} for b in bases]
2750+
for i, base in enumerate(bases):
27502751
# Attributes defined in both the type and base are skipped.
27512752
# Normal checks for attribute compatibility should catch any problems elsewhere.
2752-
non_overridden_attrs = base.names.keys() - typ.names.keys()
2753+
non_overridden_attrs = all_names[i] - typ.names.keys()
27532754
for name in non_overridden_attrs:
27542755
if is_private(name):
27552756
continue
2756-
for base2 in mro[i + 1 :]:
2757+
for j, base2 in enumerate(bases[i + 1 :], i + 1):
27572758
# We only need to check compatibility of attributes from classes not
27582759
# in a subclass relationship. For subclasses, normal (single inheritance)
27592760
# checks suffice (these are implemented elsewhere).
2760-
if name in base2.names and base2 not in base.mro:
2761+
if name in all_names[j] and base.type != base2.type:
27612762
self.check_compatibility(name, base, base2, typ)
27622763

2763-
def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
2764-
if sym.type is not None:
2765-
return sym.type
2766-
if isinstance(sym.node, FuncBase):
2767-
return self.function_type(sym.node)
2768-
if isinstance(sym.node, TypeInfo):
2769-
if sym.node.typeddict_type:
2764+
def determine_type_of_member(self, node: SymbolNode) -> Type | None:
2765+
if isinstance(node, FuncBase):
2766+
return self.function_type(node)
2767+
if isinstance(node, TypeInfo):
2768+
if node.typeddict_type:
27702769
# We special-case TypedDict, because they don't define any constructor.
2771-
return self.expr_checker.typeddict_callable(sym.node)
2770+
return self.expr_checker.typeddict_callable(node)
27722771
else:
2773-
return type_object_type(sym.node, self.named_type)
2774-
if isinstance(sym.node, TypeVarExpr):
2772+
return type_object_type(node, self.named_type)
2773+
if isinstance(node, TypeVarExpr):
27752774
# Use of TypeVars is rejected in an expression/runtime context, so
27762775
# we don't need to check supertype compatibility for them.
27772776
return AnyType(TypeOfAny.special_form)
2778-
if isinstance(sym.node, TypeAlias):
2777+
if isinstance(node, TypeAlias):
27792778
with self.msg.filter_errors():
27802779
# Suppress any errors, they will be given when analyzing the corresponding node.
27812780
# Here we may have incorrect options and location context.
2782-
return self.expr_checker.alias_type_in_runtime_context(sym.node, ctx=sym.node)
2781+
return self.expr_checker.alias_type_in_runtime_context(node, ctx=node)
27832782
# TODO: handle more node kinds here.
27842783
return None
27852784

27862785
def check_compatibility(
2787-
self, name: str, base1: TypeInfo, base2: TypeInfo, ctx: TypeInfo
2786+
self, name: str, base1: Instance, base2: Instance, ctx: TypeInfo
27882787
) -> None:
27892788
"""Check if attribute name in base1 is compatible with base2 in multiple inheritance.
27902789
@@ -2809,10 +2808,41 @@ class C(B, A[int]): ... # this is unsafe because...
28092808
if name in ("__init__", "__new__", "__init_subclass__"):
28102809
# __init__ and friends can be incompatible -- it's a special case.
28112810
return
2812-
first = base1.names[name]
2813-
second = base2.names[name]
2814-
first_type = get_proper_type(self.determine_type_of_member(first))
2815-
second_type = get_proper_type(self.determine_type_of_member(second))
2811+
first_type = first_node = None
2812+
second_type = second_node = None
2813+
orig_var = ctx.get(name)
2814+
if orig_var is not None and orig_var.node is not None:
2815+
if (b1type := base1.type.get_containing_type_info(name)) is not None:
2816+
base1 = map_instance_to_supertype(base1, b1type)
2817+
first_type, first_node = self.attribute_type_from_base(
2818+
orig_var.node, base1.type, base1
2819+
)
2820+
2821+
if (b2type := base2.type.get_containing_type_info(name)) is not None:
2822+
base2 = map_instance_to_supertype(base2, b2type)
2823+
second_type, second_node = self.attribute_type_from_base(
2824+
orig_var.node, base2.type, base2
2825+
)
2826+
2827+
# Fix the order. We iterate over the explicit bases, which means we may
2828+
# end up with the following structure:
2829+
# class A:
2830+
# def fn(self, x: int) -> None: ...
2831+
# class B(A): ...
2832+
# class C(A):
2833+
# def fn(self, x: int|str) -> None: ...
2834+
# class D(B, C): ...
2835+
# Here D.fn will actually be dispatched to C.fn which is assignable to A.fn,
2836+
# but without this fixup we'd check A.fn against C.fn instead.
2837+
# See testMultipleInheritanceTransitive in check-multiple-inheritance.test
2838+
if (
2839+
b1type is not None
2840+
and b2type is not None
2841+
and ctx.mro.index(b1type) > ctx.mro.index(b2type)
2842+
):
2843+
b1type, b2type = b2type, b1type
2844+
first_type, second_type = second_type, first_type
2845+
first_node, second_node = second_node, first_node
28162846

28172847
# TODO: use more principled logic to decide is_subtype() vs is_equivalent().
28182848
# We should rely on mutability of superclass node, not on types being Callable.
@@ -2822,7 +2852,7 @@ class C(B, A[int]): ... # this is unsafe because...
28222852
if isinstance(first_type, Instance):
28232853
call = find_member("__call__", first_type, first_type, is_operator=True)
28242854
if call and isinstance(second_type, FunctionLike):
2825-
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
2855+
second_sig = self.bind_and_map_method(second_node, second_type, ctx, base2.type)
28262856
ok = is_subtype(call, second_sig, ignore_pos_arg_names=True)
28272857
elif isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
28282858
if first_type.is_type_obj() and second_type.is_type_obj():
@@ -2834,42 +2864,70 @@ class C(B, A[int]): ... # this is unsafe because...
28342864
)
28352865
else:
28362866
# First bind/map method types when necessary.
2837-
first_sig = self.bind_and_map_method(first, first_type, ctx, base1)
2838-
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
2867+
first_sig = self.bind_and_map_method(first_node, first_type, ctx, base1.type)
2868+
second_sig = self.bind_and_map_method(second_node, second_type, ctx, base2.type)
28392869
ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True)
28402870
elif first_type and second_type:
2841-
if isinstance(first.node, Var):
2842-
first_type = expand_self_type(first.node, first_type, fill_typevars(ctx))
2843-
if isinstance(second.node, Var):
2844-
second_type = expand_self_type(second.node, second_type, fill_typevars(ctx))
2871+
if isinstance(first_node, Var):
2872+
first_type = expand_self_type(first_node, first_type, fill_typevars(ctx))
2873+
if isinstance(second_node, Var):
2874+
second_type = expand_self_type(second_node, second_type, fill_typevars(ctx))
28452875
ok = is_equivalent(first_type, second_type)
28462876
if not ok:
2847-
second_node = base2[name].node
2877+
second_var = base2.type.get(name)
28482878
if (
28492879
isinstance(second_type, FunctionLike)
2850-
and second_node is not None
2851-
and is_property(second_node)
2880+
and second_var is not None
2881+
and second_var.node is not None
2882+
and is_property(second_var.node)
28522883
):
28532884
second_type = get_property_type(second_type)
28542885
ok = is_subtype(first_type, second_type)
28552886
else:
28562887
if first_type is None:
2857-
self.msg.cannot_determine_type_in_base(name, base1.name, ctx)
2888+
self.msg.cannot_determine_type_in_base(name, base1.type.name, ctx)
28582889
if second_type is None:
2859-
self.msg.cannot_determine_type_in_base(name, base2.name, ctx)
2890+
self.msg.cannot_determine_type_in_base(name, base2.type.name, ctx)
28602891
ok = True
28612892
# Final attributes can never be overridden, but can override
28622893
# non-final read-only attributes.
2863-
if is_final_node(second.node) and not is_private(name):
2864-
self.msg.cant_override_final(name, base2.name, ctx)
2865-
if is_final_node(first.node):
2866-
self.check_if_final_var_override_writable(name, second.node, ctx)
2894+
if is_final_node(second_node) and not is_private(name):
2895+
self.msg.cant_override_final(name, base2.type.name, ctx)
2896+
if is_final_node(first_node):
2897+
self.check_if_final_var_override_writable(name, second_node, ctx)
28672898
# Some attributes like __slots__ and __deletable__ are special, and the type can
28682899
# vary across class hierarchy.
2869-
if isinstance(second.node, Var) and second.node.allow_incompatible_override:
2900+
if isinstance(second_node, Var) and second_node.allow_incompatible_override:
28702901
ok = True
28712902
if not ok:
2872-
self.msg.base_class_definitions_incompatible(name, base1, base2, ctx)
2903+
self.msg.base_class_definitions_incompatible(name, base1.type, base2.type, ctx)
2904+
2905+
def attribute_type_from_base(
2906+
self, expr_node: SymbolNode, base: TypeInfo, self_type: Instance
2907+
) -> tuple[ProperType | None, Node | None]:
2908+
"""For a NameExpr that is part of a class, walk all base classes and try
2909+
to find the first class that defines a Type for the same name."""
2910+
expr_name = expr_node.name
2911+
base_var = base.names.get(expr_name)
2912+
2913+
if base_var:
2914+
base_node = base_var.node
2915+
base_type = base_var.type
2916+
2917+
if base_type:
2918+
if not has_no_typevars(base_type):
2919+
itype = map_instance_to_supertype(self_type, base)
2920+
base_type = expand_type_by_instance(base_type, itype)
2921+
2922+
return get_proper_type(base_type), base_node
2923+
2924+
if (
2925+
base_node is not None
2926+
and (base_type := self.determine_type_of_member(base_node)) is not None
2927+
):
2928+
return get_proper_type(base_type), base_node
2929+
2930+
return None, None
28732931

28742932
def check_metaclass_compatibility(self, typ: TypeInfo) -> None:
28752933
"""Ensures that metaclasses of all parent types are compatible."""

mypy/checkexpr.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,12 @@ def module_type(self, node: MypyFile) -> Instance:
462462
continue
463463
if isinstance(n.node, Var) and n.node.is_final:
464464
immutable.add(name)
465-
typ = self.chk.determine_type_of_member(n)
465+
466+
typ = None
467+
if n.type is not None:
468+
typ = n.type
469+
elif n.node is not None:
470+
typ = self.chk.determine_type_of_member(n.node)
466471
if typ:
467472
module_attrs[name] = typ
468473
else:

mypy/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4202,7 +4202,7 @@ def is_class_var(expr: NameExpr) -> bool:
42024202
return False
42034203

42044204

4205-
def is_final_node(node: SymbolNode | None) -> bool:
4205+
def is_final_node(node: Node | None) -> bool:
42064206
"""Check whether `node` corresponds to a final attribute."""
42074207
return isinstance(node, (Var, FuncDef, OverloadedFuncDef, Decorator)) and node.is_final
42084208

test-data/unit/check-generic-subtyping.test

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ x1: X1[str, int]
946946
reveal_type(list(x1)) # N: Revealed type is "builtins.list[builtins.int]"
947947
reveal_type([*x1]) # N: Revealed type is "builtins.list[builtins.int]"
948948

949-
class X2(Generic[T, U], Iterator[U], Mapping[T, U]):
949+
class X2(Generic[T, U], Iterator[U], Mapping[T, U]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
950950
pass
951951

952952
x2: X2[str, int]
@@ -1017,10 +1017,7 @@ x1: X1[str, int]
10171017
reveal_type(iter(x1)) # N: Revealed type is "typing.Iterator[builtins.int]"
10181018
reveal_type({**x1}) # N: Revealed type is "builtins.dict[builtins.int, builtins.str]"
10191019

1020-
# Some people would expect this to raise an error, but this currently does not:
1021-
# `Mapping` has `Iterable[U]` base class, `X2` has direct `Iterable[T]` base class.
1022-
# It would be impossible to define correct `__iter__` method for incompatible `T` and `U`.
1023-
class X2(Generic[T, U], Mapping[U, T], Iterable[T]):
1020+
class X2(Generic[T, U], Mapping[U, T], Iterable[T]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
10241021
pass
10251022

10261023
x2: X2[str, int]
@@ -1065,3 +1062,31 @@ class F(E[T_co], Generic[T_co]): ... # E: Variance of TypeVar "T_co" incompatib
10651062

10661063
class G(Generic[T]): ...
10671064
class H(G[T_contra], Generic[T_contra]): ... # E: Variance of TypeVar "T_contra" incompatible with variance in parent type
1065+
1066+
[case testMultipleInheritanceCompatibleTypeVar]
1067+
from typing import Generic, TypeVar
1068+
1069+
T = TypeVar("T")
1070+
U = TypeVar("U")
1071+
1072+
class A(Generic[T]):
1073+
x: T
1074+
def fn(self, t: T) -> None: ...
1075+
1076+
class A2(A[T]):
1077+
y: str
1078+
z: str
1079+
1080+
class B(Generic[T]):
1081+
x: T
1082+
def fn(self, t: T) -> None: ...
1083+
1084+
class C1(A2[str], B[str]): pass
1085+
class C2(A2[str], B[int]): pass # E: Definition of "x" in base class "A" is incompatible with definition in base class "B" \
1086+
# E: Definition of "fn" in base class "A" is incompatible with definition in base class "B"
1087+
class C3(A2[T], B[T]): pass
1088+
class C4(A2[U], B[U]): pass
1089+
class C5(A2[U], B[T]): pass # E: Definition of "x" in base class "A" is incompatible with definition in base class "B" \
1090+
# E: Definition of "fn" in base class "A" is incompatible with definition in base class "B"
1091+
1092+
[builtins fixtures/tuple.pyi]

test-data/unit/check-multiple-inheritance.test

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,6 @@ class D2(B[Union[int, str]], C2): ...
669669
class D3(C2, B[str]): ...
670670
class D4(B[str], C2): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "C2"
671671

672-
673672
[case testMultipleInheritanceOverridingOfFunctionsWithCallableInstances]
674673
from typing import Any, Callable
675674

@@ -706,3 +705,11 @@ class C34(B3, B4): ...
706705
class C41(B4, B1): ...
707706
class C42(B4, B2): ...
708707
class C43(B4, B3): ...
708+
709+
[case testMultipleInheritanceTransitive]
710+
class A:
711+
def fn(self, x: int) -> None: ...
712+
class B(A): ...
713+
class C(A):
714+
def fn(self, x: "int | str") -> None: ...
715+
class D(B, C): ...

0 commit comments

Comments
 (0)