|
8 | 8 | import mypy.typeops |
9 | 9 | from mypy.expandtype import expand_type |
10 | 10 | from mypy.maptype import map_instance_to_supertype |
11 | | -from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY |
| 11 | +from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo |
12 | 12 | from mypy.state import state |
13 | 13 | from mypy.subtypes import ( |
14 | 14 | SubtypeContext, |
@@ -168,9 +168,20 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: |
168 | 168 | # Compute the "best" supertype of t when joined with s. |
169 | 169 | # The definition of "best" may evolve; for now it is the one with |
170 | 170 | # the longest MRO. Ties are broken by using the earlier base. |
171 | | - best: ProperType | None = None |
| 171 | + |
| 172 | + # Go over both sets of bases in case there's an explicit Protocol base. This is important |
| 173 | + # to ensure commutativity of join (although in cases where both classes have relevant |
| 174 | + # Protocol bases this maybe might still not be commutative) |
| 175 | + base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order |
172 | 176 | for base in t.type.bases: |
173 | | - mapped = map_instance_to_supertype(t, base.type) |
| 177 | + base_types[base.type] = None |
| 178 | + for base in s.type.bases: |
| 179 | + if base.type.is_protocol and is_subtype(t, base): |
| 180 | + base_types[base.type] = None |
| 181 | + |
| 182 | + best: ProperType | None = None |
| 183 | + for base_type in base_types: |
| 184 | + mapped = map_instance_to_supertype(t, base_type) |
174 | 185 | res = self.join_instances(mapped, s) |
175 | 186 | if best is None or is_better(res, best): |
176 | 187 | best = res |
@@ -662,6 +673,10 @@ def is_better(t: Type, s: Type) -> bool: |
662 | 673 | if isinstance(t, Instance): |
663 | 674 | if not isinstance(s, Instance): |
664 | 675 | return True |
| 676 | + if t.type.is_protocol != s.type.is_protocol: |
| 677 | + if t.type.fullname != "builtins.object" and s.type.fullname != "builtins.object": |
| 678 | + # mro of protocol is not really relevant |
| 679 | + return not t.type.is_protocol |
665 | 680 | # Use len(mro) as a proxy for the better choice. |
666 | 681 | if len(t.type.mro) > len(s.type.mro): |
667 | 682 | return True |
|
0 commit comments