Skip to content

Commit 4b13114

Browse files
committed
Support meet&join of callables with non-trivially corresponding parameters
1 parent 653fc9b commit 4b13114

File tree

8 files changed

+363
-97
lines changed

8 files changed

+363
-97
lines changed

mypy/build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2322,6 +2322,7 @@ def type_checker(self) -> TypeChecker:
23222322
manager.plugin,
23232323
self.per_line_checking_time_ns,
23242324
)
2325+
type_state.object_type = self._type_checker.named_type("builtins.object")
23252326
return self._type_checker
23262327

23272328
def type_map(self) -> dict[Expression, Type]:

mypy/join.py

Lines changed: 106 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from __future__ import annotations
44

55
from collections.abc import Sequence
6-
from typing import overload
6+
from typing import Callable, overload
77

88
import mypy.typeops
99
from mypy.expandtype import expand_type
1010
from mypy.maptype import map_instance_to_supertype
11-
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
11+
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, ArgKind
1212
from mypy.state import state
1313
from mypy.subtypes import (
1414
SubtypeContext,
@@ -52,6 +52,7 @@
5252
get_proper_types,
5353
split_with_prefix_and_suffix,
5454
)
55+
from mypy.typestate import type_state
5556

5657

5758
class InstanceJoiner:
@@ -306,17 +307,7 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType:
306307

307308
def visit_parameters(self, t: Parameters) -> ProperType:
308309
if isinstance(self.s, Parameters):
309-
if not is_similar_params(t, self.s):
310-
# TODO: it would be prudent to return [*object, **object] instead of Any.
311-
return self.default(self.s)
312-
from mypy.meet import meet_types
313-
314-
return t.copy_modified(
315-
arg_types=[
316-
meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)
317-
],
318-
arg_names=combine_arg_names(self.s, t),
319-
)
310+
return join_parameters(self.s, t) or self.default(self.s)
320311
else:
321312
return self.default(self.s)
322313

@@ -354,10 +345,12 @@ def visit_instance(self, t: Instance) -> ProperType:
354345
return self.default(self.s)
355346

356347
def visit_callable_type(self, t: CallableType) -> ProperType:
357-
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
348+
if isinstance(self.s, CallableType):
358349
if is_equivalent(t, self.s):
359350
return combine_similar_callables(t, self.s)
360351
result = join_similar_callables(t, self.s)
352+
if result is None:
353+
return join_types(t.fallback, self.s)
361354
# We set the from_type_type flag to suppress error when a collection of
362355
# concrete class objects gets inferred as their common abstract superclass.
363356
if not (
@@ -416,11 +409,12 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
416409
# The interesting case where both types are function types.
417410
for t_item in t.items:
418411
for s_item in s.items:
419-
if is_similar_callables(t_item, s_item):
420-
if is_equivalent(t_item, s_item):
421-
result.append(combine_similar_callables(t_item, s_item))
422-
elif is_subtype(t_item, s_item):
423-
result.append(s_item)
412+
if is_equivalent(t_item, s_item):
413+
result.append(combine_similar_callables(t_item, s_item))
414+
elif is_subtype(t_item, s_item):
415+
result.append(s_item)
416+
elif (true_join := join_similar_callables(s_item, t_item)) is not None:
417+
result.append(true_join)
424418
if result:
425419
# TODO: Simplify redundancies from the result.
426420
if len(result) == 1:
@@ -638,6 +632,8 @@ def default(self, typ: Type) -> ProperType:
638632
return self.default(typ.upper_bound)
639633
elif isinstance(typ, ParamSpecType):
640634
return self.default(typ.upper_bound)
635+
elif type_state.object_type is not None:
636+
return type_state.object_type
641637
else:
642638
return AnyType(TypeOfAny.special_form)
643639

@@ -665,26 +661,6 @@ def normalize_callables(s: ProperType, t: ProperType) -> tuple[ProperType, Prope
665661
return s, t
666662

667663

668-
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
669-
"""Return True if t and s have identical numbers of
670-
arguments, default arguments and varargs.
671-
"""
672-
return (
673-
len(t.arg_types) == len(s.arg_types)
674-
and t.min_args == s.min_args
675-
and t.is_var_arg == s.is_var_arg
676-
)
677-
678-
679-
def is_similar_params(t: Parameters, s: Parameters) -> bool:
680-
# This matches the logic in is_similar_callables() above.
681-
return (
682-
len(t.arg_types) == len(s.arg_types)
683-
and t.min_args == s.min_args
684-
and (t.var_arg() is not None) == (s.var_arg() is not None)
685-
)
686-
687-
688664
def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
689665
tv_map = {}
690666
tvs = []
@@ -712,11 +688,83 @@ def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableT
712688
return update_callable_ids(t, new_ids), update_callable_ids(s, new_ids)
713689

714690

715-
def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
691+
def join_parameters(s: Parameters, t: Parameters) -> Parameters | None:
692+
from mypy.meet import meet_types
693+
694+
return combine_parameters_with(s, t, arg_transformer=meet_types, allow_uninhabited=False)
695+
696+
697+
def combine_parameters_with(
698+
s: Parameters,
699+
t: Parameters,
700+
arg_transformer: Callable[[Type, Type], Type],
701+
allow_uninhabited: bool,
702+
) -> Parameters | None:
703+
if sum(k.is_required() for k in s.arg_kinds) < sum(k.is_required() for k in t.arg_kinds):
704+
return join_parameters(t, s)
705+
706+
args_meet = []
707+
arg_names: list[str | None] = []
708+
arg_kinds: list[ArgKind] = []
709+
vararg = None
710+
kwarg = None
711+
if (s_var := s.var_arg()) is not None and (t_var := t.var_arg()) is not None:
712+
vararg = arg_transformer(s_var.typ, t_var.typ)
713+
if isinstance(get_proper_type(vararg), UninhabitedType):
714+
vararg = None
715+
if (s_kw := s.kw_arg()) is not None and (t_kw := t.kw_arg()) is not None:
716+
kwarg = arg_transformer(s_kw.typ, t_kw.typ)
717+
if isinstance(get_proper_type(kwarg), UninhabitedType):
718+
kwarg = None
719+
for s_kind, s_a in zip(s.arg_kinds, s.formal_arguments(include_star_args=True)):
720+
if vararg is not None and s_kind == ArgKind.ARG_STAR:
721+
args_meet.append(vararg)
722+
arg_names.append(None)
723+
arg_kinds.append(s_kind)
724+
vararg = None
725+
continue
726+
if kwarg is not None and s_kind == ArgKind.ARG_STAR2:
727+
args_meet.append(kwarg)
728+
arg_names.append(None)
729+
arg_kinds.append(s_kind)
730+
kwarg = None
731+
continue
732+
if s_kind.is_star():
733+
continue
734+
t_a = t.argument_by_position(s_a.pos) or t.argument_by_name(s_a.name)
735+
if t_a is None:
736+
if s_a.required:
737+
return None
738+
continue
739+
typ = arg_transformer(s_a.typ, t_a.typ)
740+
if not allow_uninhabited and isinstance(get_proper_type(typ), UninhabitedType):
741+
return None
742+
args_meet.append(typ)
743+
arg_names.append(s_a.name if s_a.name == t_a.name else None)
744+
kinds = [ArgKind.ARG_POS, ArgKind.ARG_OPT, ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]
745+
if s_a.pos != t_a.pos or s_a.pos is None or t_a.pos is None:
746+
kinds = [k for k in kinds if not k.is_positional()]
747+
if s_a.name != t_a.name or s_a.name is None or t_a.name is None:
748+
kinds = [k for k in kinds if not k.is_named()]
749+
if s_a.required or t_a.required:
750+
kinds = [k for k in kinds if k.is_required()]
751+
arg_kinds.append(kinds[0])
752+
return t.copy_modified(arg_types=args_meet, arg_names=arg_names, arg_kinds=arg_kinds)
753+
754+
755+
def join_similar_callables(t: CallableType, s: CallableType) -> CallableType | None:
756+
if s.param_spec() != t.param_spec():
757+
return None
758+
716759
t, s = match_generic_callables(t, s)
717-
arg_types: list[Type] = []
718-
for i in range(len(t.arg_types)):
719-
arg_types.append(safe_meet(t.arg_types[i], s.arg_types[i]))
760+
761+
joined_params = join_parameters(
762+
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
763+
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
764+
)
765+
if joined_params is None:
766+
return None
767+
720768
# TODO in combine_similar_callables also applies here (names and kinds; user metaclasses)
721769
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
722770
# The result should always use 'function' as a fallback if either operands are using it.
@@ -725,8 +773,9 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
725773
else:
726774
fallback = s.fallback
727775
return t.copy_modified(
728-
arg_types=arg_types,
729-
arg_names=combine_arg_names(t, s),
776+
arg_types=joined_params.arg_types,
777+
arg_names=joined_params.arg_names,
778+
arg_kinds=joined_params.arg_kinds,
730779
ret_type=join_types(t.ret_type, s.ret_type),
731780
fallback=fallback,
732781
name=None,
@@ -767,10 +816,15 @@ def safe_meet(t: Type, s: Type) -> Type:
767816

768817
def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
769818
t, s = match_generic_callables(t, s)
770-
arg_types: list[Type] = []
771-
for i in range(len(t.arg_types)):
772-
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
773-
# TODO kinds and argument names
819+
820+
joined_params = combine_parameters_with(
821+
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
822+
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
823+
arg_transformer=safe_join,
824+
allow_uninhabited=True,
825+
)
826+
assert joined_params is not None
827+
774828
# TODO what should happen if one fallback is 'type' and the other is a user-provided metaclass?
775829
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
776830
# The result should always use 'function' as a fallback if either operands are using it.
@@ -779,45 +833,15 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
779833
else:
780834
fallback = s.fallback
781835
return t.copy_modified(
782-
arg_types=arg_types,
783-
arg_names=combine_arg_names(t, s),
836+
arg_types=joined_params.arg_types,
837+
arg_names=joined_params.arg_names,
838+
arg_kinds=joined_params.arg_kinds,
784839
ret_type=join_types(t.ret_type, s.ret_type),
785840
fallback=fallback,
786841
name=None,
787842
)
788843

789844

790-
def combine_arg_names(
791-
t: CallableType | Parameters, s: CallableType | Parameters
792-
) -> list[str | None]:
793-
"""Produces a list of argument names compatible with both callables.
794-
795-
For example, suppose 't' and 's' have the following signatures:
796-
797-
- t: (a: int, b: str, X: str) -> None
798-
- s: (a: int, b: str, Y: str) -> None
799-
800-
This function would return ["a", "b", None]. This information
801-
is then used above to compute the join of t and s, which results
802-
in a signature of (a: int, b: str, str) -> None.
803-
804-
Note that the third argument's name is omitted and 't' and 's'
805-
are both valid subtypes of this inferred signature.
806-
807-
Precondition: is_similar_types(t, s) is true.
808-
"""
809-
num_args = len(t.arg_types)
810-
new_names = []
811-
for i in range(num_args):
812-
t_name = t.arg_names[i]
813-
s_name = s.arg_names[i]
814-
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
815-
new_names.append(t_name)
816-
else:
817-
new_names.append(None)
818-
return new_names
819-
820-
821845
def object_from_instance(instance: Instance) -> Instance:
822846
"""Construct the type 'builtins.object' from an instance type."""
823847
# Use the fact that 'object' is always the last class in the mro.

mypy/meet.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,14 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
789789
return self.default(self.s)
790790

791791
def visit_unpack_type(self, t: UnpackType) -> ProperType:
792-
raise NotImplementedError
792+
if isinstance(self.s, UnpackType):
793+
res = UnpackType(
794+
meet_types(self.s.type, t.type), from_star_syntax=self.s.from_star_syntax
795+
)
796+
res.set_line(self.s)
797+
return res
798+
else:
799+
return self.default(self.s)
793800

794801
def visit_parameters(self, t: Parameters) -> ProperType:
795802
if isinstance(self.s, Parameters):
@@ -889,10 +896,12 @@ def visit_instance(self, t: Instance) -> ProperType:
889896
return self.default(self.s)
890897

891898
def visit_callable_type(self, t: CallableType) -> ProperType:
892-
if isinstance(self.s, CallableType) and join.is_similar_callables(t, self.s):
899+
if isinstance(self.s, CallableType):
893900
if is_equivalent(t, self.s):
894901
return join.combine_similar_callables(t, self.s)
895902
result = meet_similar_callables(t, self.s)
903+
if result is None:
904+
return self.default(self.s)
896905
# We set the from_type_type flag to suppress error when a collection of
897906
# concrete class objects gets inferred as their common abstract superclass.
898907
if not (
@@ -1099,13 +1108,21 @@ def default(self, typ: Type) -> ProperType:
10991108
return NoneType()
11001109

11011110

1102-
def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
1103-
from mypy.join import match_generic_callables, safe_join
1111+
def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType | None:
1112+
from mypy.join import combine_parameters_with, match_generic_callables, safe_join
1113+
1114+
if s.param_spec() != t.param_spec():
1115+
return None
11041116

11051117
t, s = match_generic_callables(t, s)
1106-
arg_types: list[Type] = []
1107-
for i in range(len(t.arg_types)):
1108-
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
1118+
joined_params = combine_parameters_with(
1119+
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
1120+
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
1121+
arg_transformer=safe_join,
1122+
allow_uninhabited=True,
1123+
)
1124+
if joined_params is None:
1125+
return None
11091126
# TODO in combine_similar_callables also applies here (names and kinds)
11101127
# The fallback type can be either 'function' or 'type'. The result should have 'function' as
11111128
# fallback only if both operands have it as 'function'.
@@ -1114,7 +1131,9 @@ def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
11141131
else:
11151132
fallback = s.fallback
11161133
return t.copy_modified(
1117-
arg_types=arg_types,
1134+
arg_types=joined_params.arg_types,
1135+
arg_names=joined_params.arg_names,
1136+
arg_kinds=joined_params.arg_kinds,
11181137
ret_type=meet_types(t.ret_type, s.ret_type),
11191138
fallback=fallback,
11201139
name=None,

mypy/solve.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from mypy.typeops import get_all_type_vars
1616
from mypy.types import (
1717
AnyType,
18+
CallableType,
1819
Instance,
1920
NoneType,
2021
ParamSpecType,
@@ -576,6 +577,8 @@ def pre_validate_solutions(
576577

577578
def is_callable_protocol(t: Type) -> bool:
578579
proper_t = get_proper_type(t)
580+
if isinstance(proper_t, CallableType) and proper_t.is_ellipsis_args:
581+
return True
579582
if isinstance(proper_t, Instance) and proper_t.type.is_protocol:
580583
return "__call__" in proper_t.type.protocol_members
581584
return False

mypy/typestate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,7 @@ class TypeState:
9797
# This is temporary and will be removed soon when new algorithm is more polished.
9898
infer_polymorphic: bool
9999

100-
# N.B: We do all of the accesses to these properties through
101-
# TypeState, instead of making these classmethods and accessing
102-
# via the cls parameter, since mypyc can optimize accesses to
103-
# Final attributes of a directly referenced type.
100+
object_type: Instance | None
104101

105102
def __init__(self) -> None:
106103
self._subtype_caches = {}
@@ -114,6 +111,7 @@ def __init__(self) -> None:
114111
self.inferring = []
115112
self.infer_unions = False
116113
self.infer_polymorphic = False
114+
self.object_type = None
117115

118116
def is_assumed_subtype(self, left: Type, right: Type) -> bool:
119117
for l, r in reversed(self._assuming):
@@ -140,6 +138,7 @@ def reset_all_subtype_caches(self) -> None:
140138
"""Completely reset all known subtype caches."""
141139
self._subtype_caches.clear()
142140
self._negative_subtype_caches.clear()
141+
self.object_type = None
143142

144143
def reset_subtype_caches_for(self, info: TypeInfo) -> None:
145144
"""Reset subtype caches (if any) for a given supertype TypeInfo."""

0 commit comments

Comments
 (0)