Skip to content
Draft
23 changes: 23 additions & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypeVarId,
Expand All @@ -27,10 +28,20 @@
UninhabitedType,
UnpackType,
get_proper_type,
get_proper_types,
remove_dups,
)


def _is_tuple_any(typ: ProperType) -> bool:
return (
isinstance(typ, Instance)
and typ.type.fullname == "builtins.tuple"
and len(typ.args) == 1
and isinstance(get_proper_type(typ.args[0]), AnyType)
)


def get_target_type(
tvar: TypeVarLikeType,
type: Type,
Expand All @@ -56,6 +67,18 @@ def get_target_type(
# is also a legal value of T.
if all(any(mypy.subtypes.is_same_type(v, v1) for v in values) for v1 in p_type.values):
return type
if _is_tuple_any(p_type) and all(
isinstance(v, TupleType)
or isinstance(v, Instance)
and v.type.fullname == "builtins.tuple"
for v in get_proper_types(values)
):
# tuple[Any, ...] is compatible with any tuple bounds. It is important
# to not select one of the values in cases like numpy arrays shape. Given
# T = TypeVar("T", tuple[()], tuple[int], tuple[int, int])
# and a proposed solution `tuple[Any, ...]`, we do not want to choose
# tuple[()] arbitrarily.
return type
matching = []
for value in values:
if mypy.subtypes.is_subtype(type, value):
Expand Down
60 changes: 48 additions & 12 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2892,17 +2892,35 @@ def infer_overload_return_type(
Assumes all of the given targets have argument counts compatible with the caller.
"""

matches: list[CallableType] = []
return_types: list[Type] = []
inferred_types: list[Type] = []
args_contain_any = any(map(has_any_type, arg_types))
type_maps: list[dict[Expression, Type]] = []

# First do a pass without external context and find all overloads that
# can be possibly matched. If no Any is present among args, bail out early
# on the first match.
candidates = []
for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map as m:
ret_type, infer_type = self.check_call(
with self.msg.filter_errors() as w, self.chk.local_type_map as m:
# Overload selection should not depend on the context.
# During this step pretend that we do not have any external information.
self.type_context.append(None)
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
self.type_context.pop()
is_match = not w.has_new_errors()
if is_match:
# Return early if possible
if not args_contain_any:
# Yes, just again
# FIXME: find a way to avoid doing this
return self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
Expand All @@ -2911,13 +2929,31 @@ def infer_overload_return_type(
callable_name=callable_name,
object_type=object_type,
)
candidates.append(typ)

# Repeat the same with outer context, but only for the select candidates.
matches: list[CallableType] = []
return_types: list[Type] = []
inferred_types: list[Type] = []
type_maps: list[dict[Expression, Type]] = []

for typ in candidates:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w, self.chk.local_type_map as m:
# Overload selection should not depend on the context.
# During this step pretend that we do not have any external information.
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
# check for ambiguity due to 'Any' below.
if not args_contain_any:
self.chk.store_types(m)
return ret_type, infer_type
# Record info, so we can check for ambiguity due to 'Any' below.
p_infer_type = get_proper_type(infer_type)
if isinstance(p_infer_type, CallableType):
# Prefer inferred types if possible, this will avoid false triggers for
Expand Down
10 changes: 5 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,21 +571,21 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l
def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | None:
"""Keep only constraints that can possibly be satisfied.

Currently, we filter out constraints where target is not a subtype of the upper bound.
Currently, we filter out constraints where target does not overlap with the upper bound.
Since those can be never satisfied. We may add more cases in future if it improves type
inference.
"""
from mypy.meet import is_overlapping_types

if not option:
return option

satisfiable = []
for c in option:
if isinstance(c.origin_type_var, TypeVarType) and c.origin_type_var.values:
if any(
mypy.subtypes.is_subtype(c.target, value) for value in c.origin_type_var.values
):
if any(is_overlapping_types(c.target, value) for value in c.origin_type_var.values):
satisfiable.append(c)
elif mypy.subtypes.is_subtype(c.target, c.origin_type_var.upper_bound):
elif is_overlapping_types(c.target, c.origin_type_var.upper_bound):
satisfiable.append(c)
if not satisfiable:
return None
Expand Down
13 changes: 5 additions & 8 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,16 +576,13 @@ def pre_validate_solutions(
new_solutions.append(s)
continue
if s is not None and not is_subtype(s, t.upper_bound):
bound_satisfies_all = True
bound = t.upper_bound
for c in constraints:
if c.op == SUBTYPE_OF and not is_subtype(t.upper_bound, c.target):
bound_satisfies_all = False
bound = meet_types(bound, c.target)
if isinstance(bound, UninhabitedType):
break
if c.op == SUPERTYPE_OF and not is_subtype(c.target, t.upper_bound):
bound_satisfies_all = False
break
if bound_satisfies_all:
new_solutions.append(t.upper_bound)
else:
new_solutions.append(bound)
continue
new_solutions.append(s)
return new_solutions
Expand Down
6 changes: 3 additions & 3 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,7 @@ TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: Union[TA, TB, int]) -> None:
a2 = replace(b_or_t) # E: Value of type variable "_DataclassT" of "replace" cannot be "Union[TA, TB, int]"
a2 = replace(b_or_t) # E: Argument 1 to "replace" has incompatible type "Union[TA, TB, int]"; expected "TA"

[builtins fixtures/tuple.pyi]

Expand Down Expand Up @@ -2202,9 +2202,9 @@ replace(None) # E: Value of type variable "_DataclassT" of "replace" cannot be
from dataclasses import is_dataclass, replace

def f(x: object) -> None:
_ = replace(x) # E: Value of type variable "_DataclassT" of "replace" cannot be "object"
_ = replace(x) # E: Argument 1 to "replace" has incompatible type "object"; expected "DataclassInstance"
if is_dataclass(x):
_ = replace(x) # E: Value of type variable "_DataclassT" of "replace" cannot be "Union[DataclassInstance, type[DataclassInstance]]"
_ = replace(x) # E: Argument 1 to "replace" has incompatible type "Union[DataclassInstance, type[DataclassInstance]]"; expected "DataclassInstance"
if not isinstance(x, type):
_ = replace(x)

Expand Down
16 changes: 8 additions & 8 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -3061,10 +3061,10 @@ main:15: error: Unsupported left operand type for >= ("NoCmp")
[case testAttrsIncrementalDunder]
from a import A
reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> a.A"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool"

A(1) < A(2)
A(1) <= A(2)
Expand Down Expand Up @@ -3098,10 +3098,10 @@ class A:
[stale]
[out2]
main:2: note: Revealed type is "def (a: builtins.int) -> a.A"
main:3: note: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool"
main:4: note: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
main:5: note: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
main:6: note: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
main:3: note: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
main:4: note: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
main:5: note: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
main:6: note: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool"
main:15: error: Unsupported operand types for < ("A" and "int")
main:16: error: Unsupported operand types for <= ("A" and "int")
main:17: error: Unsupported operand types for > ("A" and "int")
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -1662,6 +1662,21 @@ reveal_type(func2(b)) # N: Revealed type is "Literal[4]"
reveal_type(func2(c)) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
[out]
main:13: error: Value of type variable "TLiteral" of "func1" cannot be "TInt"
main:20: note: Revealed type is "def [TLiteral <: Literal[3]] (x: TLiteral`-1) -> TLiteral`-1"
main:22: note: Revealed type is "Literal[3]"
main:23: note: Revealed type is "Literal[3]"
main:24: error: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
main:24: note: Revealed type is "Literal[4]"
main:25: error: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
main:25: note: Revealed type is "Literal[4]"
main:26: note: Revealed type is "Literal[3]"
main:26: error: Argument 1 to "func1" has incompatible type "int"; expected "Literal[3]"
main:28: note: Revealed type is "builtins.int"
main:29: note: Revealed type is "Literal[3]"
main:30: note: Revealed type is "builtins.int"
main:31: note: Revealed type is "Literal[4]"
main:32: note: Revealed type is "builtins.int"

[case testLiteralAndGenericsRespectsValueRestriction]
from typing import Literal, TypeVar
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ class NT(NamedTuple, Generic[T]):
return self._replace()

class SNT(NT[int]): ...
reveal_type(SNT("test", 42).meth()) # N: Revealed type is "tuple[builtins.str, builtins.int, fallback=__main__.SNT]"
reveal_type(SNT("test", 42).meth()) # N: Revealed type is "tuple[builtins.str, Never, fallback=__main__.SNT]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-namedtuple.pyi]

Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,14 @@ def g(x: U, y: V) -> None:
f([x, y]) # E: Value of type variable "T" of "f" cannot be "object"
[builtins fixtures/list.pyi]
[out]
tmp/foo.pyi:12: error: "mystr" not callable
tmp/foo.pyi:13: error: No overload variant of "f" matches argument type "V"
tmp/foo.pyi:13: note: Possible overload variants:
tmp/foo.pyi:13: note: def [T: str] f(x: T) -> T
tmp/foo.pyi:13: note: def [T: str] f(x: list[T]) -> None
tmp/foo.pyi:15: note: Revealed type is "None"
tmp/foo.pyi:16: error: Value of type variable "T" of "f" cannot be "V"
tmp/foo.pyi:17: error: List item 1 has incompatible type "V"; expected "str"

[case testOverloadOverlapWithTypeVars]
from foo import *
Expand Down Expand Up @@ -6852,3 +6860,18 @@ if isinstance(headers, dict):

reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]"
[builtins fixtures/isinstancelist.pyi]

[case testOverloadSelectionIgnoresContext]
from typing import TypeVar, overload

_T = TypeVar("_T")

@overload # type: ignore[no-overload-impl]
def gather(f1: _T) -> tuple[_T]: ...
@overload
def gather(*fns: object) -> int: ...

def crash() -> None:
foo: str
(foo,) = gather(0) # E: Argument 1 to "gather" has incompatible type "int"; expected "str"
[builtins fixtures/tuple.pyi]
16 changes: 8 additions & 8 deletions test-data/unit/check-plugin-attrs.test
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ from attr import attrib, attrs
class A:
a: int
reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> __main__.A"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool"

A(1) < A(2)
A(1) <= A(2)
Expand Down Expand Up @@ -990,10 +990,10 @@ class C(A, B): pass
@attr.s
class D(A): pass

reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`29, other: _AT`29) -> builtins.bool"
reveal_type(B.__lt__) # N: Revealed type is "def [_AT] (self: _AT`30, other: _AT`30) -> builtins.bool"
reveal_type(C.__lt__) # N: Revealed type is "def [_AT] (self: _AT`31, other: _AT`31) -> builtins.bool"
reveal_type(D.__lt__) # N: Revealed type is "def [_AT] (self: _AT`32, other: _AT`32) -> builtins.bool"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`33, other: _AT`33) -> builtins.bool"
reveal_type(B.__lt__) # N: Revealed type is "def [_AT] (self: _AT`34, other: _AT`34) -> builtins.bool"
reveal_type(C.__lt__) # N: Revealed type is "def [_AT] (self: _AT`35, other: _AT`35) -> builtins.bool"
reveal_type(D.__lt__) # N: Revealed type is "def [_AT] (self: _AT`36, other: _AT`36) -> builtins.bool"

A() < A()
B() < B()
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ MyDDict(dict)[0]
_program.py:6: error: Argument 1 to "defaultdict" has incompatible type "type[list[_T]]"; expected "Optional[Callable[[], str]]"
_program.py:9: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int"
_program.py:9: error: Incompatible types in assignment (expression has type "int", target has type "str")
_program.py:19: error: Argument 1 to "tst" has incompatible type "defaultdict[str, list[Never]]"; expected "defaultdict[int, list[Never]]"
_program.py:19: error: Dict entry 0 has incompatible type "str": "list[Never]"; expected "int": "list[Never]"
_program.py:23: error: Invalid index type "str" for "MyDDict[dict[Never, Never]]"; expected type "int"

[case testCollectionsAliases]
Expand Down
Loading