diff --git a/mypy/applytype.py b/mypy/applytype.py index dfeaf7752d21..cd0f688fe257 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -18,6 +18,7 @@ ParamSpecType, PartialType, ProperType, + TupleType, Type, TypeAliasType, TypeVarId, @@ -27,10 +28,20 @@ UninhabitedType, UnpackType, get_proper_type, + get_proper_types, remove_dups, ) +def _is_tuple_any(typ: ProperType) -> bool: + return ( + isinstance(typ, Instance) + and typ.type.fullname == "builtins.tuple" + and len(typ.args) == 1 + and isinstance(get_proper_type(typ.args[0]), AnyType) + ) + + def get_target_type( tvar: TypeVarLikeType, type: Type, @@ -56,6 +67,18 @@ def get_target_type( # is also a legal value of T. if all(any(mypy.subtypes.is_same_type(v, v1) for v in values) for v1 in p_type.values): return type + if _is_tuple_any(p_type) and all( + isinstance(v, TupleType) + or isinstance(v, Instance) + and v.type.fullname == "builtins.tuple" + for v in get_proper_types(values) + ): + # tuple[Any, ...] is compatible with any tuple bounds. It is important + # to not select one of the values in cases like numpy arrays shape. Given + # T = TypeVar("T", tuple[()], tuple[int], tuple[int, int]) + # and a proposed solution `tuple[Any, ...]`, we do not want to choose + # tuple[()] arbitrarily. + return type matching = [] for value in values: if mypy.subtypes.is_subtype(type, value): diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3eb54579a050..98d3b5075ed0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2892,17 +2892,35 @@ def infer_overload_return_type( Assumes all of the given targets have argument counts compatible with the caller. """ - matches: list[CallableType] = [] - return_types: list[Type] = [] - inferred_types: list[Type] = [] args_contain_any = any(map(has_any_type, arg_types)) - type_maps: list[dict[Expression, Type]] = [] + # First do a pass without external context and find all overloads that + # can be possibly matched. If no Any is present among args, bail out early + # on the first match. + candidates = [] for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map as m: - ret_type, infer_type = self.check_call( + with self.msg.filter_errors() as w, self.chk.local_type_map as m: + # Overload selection should not depend on the context. + # During this step pretend that we do not have any external information. + self.type_context.append(None) + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) + self.type_context.pop() + is_match = not w.has_new_errors() + if is_match: + # Return early if possible + if not args_contain_any: + # Yes, just again + # FIXME: find a way to avoid doing this + return self.check_call( callee=typ, args=args, arg_kinds=arg_kinds, @@ -2911,13 +2929,31 @@ def infer_overload_return_type( callable_name=callable_name, object_type=object_type, ) + candidates.append(typ) + + # Repeat the same with outer context, but only for the select candidates. + matches: list[CallableType] = [] + return_types: list[Type] = [] + inferred_types: list[Type] = [] + type_maps: list[dict[Expression, Type]] = [] + + for typ in candidates: + assert self.msg is self.chk.msg + with self.msg.filter_errors() as w, self.chk.local_type_map as m: + # Overload selection should not depend on the context. + # During this step pretend that we do not have any external information. + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: - # Return early if possible; otherwise record info, so we can - # check for ambiguity due to 'Any' below. - if not args_contain_any: - self.chk.store_types(m) - return ret_type, infer_type + # Record info, so we can check for ambiguity due to 'Any' below. p_infer_type = get_proper_type(infer_type) if isinstance(p_infer_type, CallableType): # Prefer inferred types if possible, this will avoid false triggers for diff --git a/mypy/constraints.py b/mypy/constraints.py index 96c0c7ccaf35..7ecef2c29efd 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -571,21 +571,21 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | None: """Keep only constraints that can possibly be satisfied. - Currently, we filter out constraints where target is not a subtype of the upper bound. + Currently, we filter out constraints where target does not overlap with the upper bound. Since those can be never satisfied. We may add more cases in future if it improves type inference. """ + from mypy.meet import is_overlapping_types + if not option: return option satisfiable = [] for c in option: if isinstance(c.origin_type_var, TypeVarType) and c.origin_type_var.values: - if any( - mypy.subtypes.is_subtype(c.target, value) for value in c.origin_type_var.values - ): + if any(is_overlapping_types(c.target, value) for value in c.origin_type_var.values): satisfiable.append(c) - elif mypy.subtypes.is_subtype(c.target, c.origin_type_var.upper_bound): + elif is_overlapping_types(c.target, c.origin_type_var.upper_bound): satisfiable.append(c) if not satisfiable: return None diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 94f65a950062..bd570415bdc3 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -3061,10 +3061,10 @@ main:15: error: Unsupported left operand type for >= ("NoCmp") [case testAttrsIncrementalDunder] from a import A reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> a.A" -reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool" -reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" -reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" -reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" +reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" +reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool" A(1) < A(2) A(1) <= A(2) @@ -3098,10 +3098,10 @@ class A: [stale] [out2] main:2: note: Revealed type is "def (a: builtins.int) -> a.A" -main:3: note: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool" -main:4: note: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" -main:5: note: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" -main:6: note: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +main:3: note: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" +main:4: note: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" +main:5: note: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +main:6: note: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool" main:15: error: Unsupported operand types for < ("A" and "int") main:16: error: Unsupported operand types for <= ("A" and "int") main:17: error: Unsupported operand types for > ("A" and "int") diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index be55a182b87b..0d581c2ddcb6 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6852,3 +6852,18 @@ if isinstance(headers, dict): reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]" [builtins fixtures/isinstancelist.pyi] + +[case testOverloadSelectionIgnoresContext] +from typing import TypeVar, overload + +_T = TypeVar("_T") + +@overload # type: ignore[no-overload-impl] +def gather(f1: _T) -> tuple[_T]: ... +@overload +def gather(*fns: object) -> int: ... + +def crash() -> None: + foo: str + (foo,) = gather(0) # E: Argument 1 to "gather" has incompatible type "int"; expected "str" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index 42f21e945ef0..1e07f8be64ac 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -185,10 +185,10 @@ from attr import attrib, attrs class A: a: int reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> __main__.A" -reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool" -reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" -reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" -reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" +reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" +reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool" A(1) < A(2) A(1) <= A(2) @@ -990,10 +990,10 @@ class C(A, B): pass @attr.s class D(A): pass -reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`29, other: _AT`29) -> builtins.bool" -reveal_type(B.__lt__) # N: Revealed type is "def [_AT] (self: _AT`30, other: _AT`30) -> builtins.bool" -reveal_type(C.__lt__) # N: Revealed type is "def [_AT] (self: _AT`31, other: _AT`31) -> builtins.bool" -reveal_type(D.__lt__) # N: Revealed type is "def [_AT] (self: _AT`32, other: _AT`32) -> builtins.bool" +reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`33, other: _AT`33) -> builtins.bool" +reveal_type(B.__lt__) # N: Revealed type is "def [_AT] (self: _AT`34, other: _AT`34) -> builtins.bool" +reveal_type(C.__lt__) # N: Revealed type is "def [_AT] (self: _AT`35, other: _AT`35) -> builtins.bool" +reveal_type(D.__lt__) # N: Revealed type is "def [_AT] (self: _AT`36, other: _AT`36) -> builtins.bool" A() < A() B() < B() diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 2069d082df17..fa500036b547 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -825,7 +825,7 @@ MyDDict(dict)[0] _program.py:6: error: Argument 1 to "defaultdict" has incompatible type "type[list[_T]]"; expected "Optional[Callable[[], str]]" _program.py:9: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" _program.py:9: error: Incompatible types in assignment (expression has type "int", target has type "str") -_program.py:19: error: Argument 1 to "tst" has incompatible type "defaultdict[str, list[Never]]"; expected "defaultdict[int, list[Never]]" +_program.py:19: error: Dict entry 0 has incompatible type "str": "list[Never]"; expected "int": "list[Never]" _program.py:23: error: Invalid index type "str" for "MyDDict[dict[Never, Never]]"; expected type "int" [case testCollectionsAliases]