Skip to content

Commit 5b6a5e4

Browse files
committed
Infer type var tuple contents in more situations
1 parent 4c5b03d commit 5b6a5e4

File tree

3 files changed

+62
-18
lines changed

3 files changed

+62
-18
lines changed

mypy/constraints.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ def infer_constraints_for_callable(
136136
incomplete_star_mapping = True # type: ignore[unreachable]
137137
break
138138

139+
# some constraints are more likely right than others
140+
# so we store them separately and remove unlikely ones later
141+
priority_constraints = []
142+
139143
for i, actuals in enumerate(formal_to_actual):
140144
if isinstance(callee.arg_types[i], UnpackType):
141145
unpack_type = callee.arg_types[i]
@@ -178,7 +182,7 @@ def infer_constraints_for_callable(
178182
)
179183

180184
if isinstance(unpacked_type, TypeVarTupleType):
181-
constraints.append(
185+
priority_constraints.append(
182186
Constraint(
183187
unpacked_type,
184188
SUPERTYPE_OF,
@@ -273,7 +277,10 @@ def infer_constraints_for_callable(
273277
if any(isinstance(v, ParamSpecType) for v in callee.variables):
274278
# As a perf optimization filter imprecise constraints only when we can have them.
275279
constraints = filter_imprecise_kinds(constraints)
276-
return constraints
280+
for pconstraint in priority_constraints:
281+
constraints = [c for c in constraints if c.type_var != pconstraint.type_var]
282+
283+
return constraints + priority_constraints
277284

278285

279286
def infer_constraints(
@@ -1108,6 +1115,13 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
11081115
# (with literal '...').
11091116
if not template.is_ellipsis_args:
11101117
unpack_present = find_unpack_in_list(template.arg_types)
1118+
1119+
# TODO: do we need some special-casing when unpack is present in actual
1120+
# callable but not in template callable?
1121+
res.extend(
1122+
infer_callable_arguments_constraints(template, cactual, self.direction)
1123+
)
1124+
11111125
# When both ParamSpec and TypeVarTuple are present, things become messy
11121126
# quickly. For now, we only allow ParamSpec to "capture" TypeVarTuple,
11131127
# but not vice versa.
@@ -1125,12 +1139,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
11251139
template_types, actual_types, neg_op(self.direction)
11261140
)
11271141
res.extend(unpack_constraints)
1128-
else:
1129-
# TODO: do we need some special-casing when unpack is present in actual
1130-
# callable but not in template callable?
1131-
res.extend(
1132-
infer_callable_arguments_constraints(template, cactual, self.direction)
1133-
)
11341142
else:
11351143
prefix = param_spec.prefix
11361144
prefix_len = len(prefix.arg_types)
@@ -1464,6 +1472,7 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[T
14641472
list with unpack in the middle, and prefix/suffix on the sides (as they would appear
14651473
in e.g. a TupleType).
14661474
"""
1475+
# TODO: don't repack kw-only args, e.g. with `(a: int, *, b: int)`
14671476
if ARG_STAR not in callable.arg_kinds:
14681477
return callable.arg_types
14691478
star_index = callable.arg_kinds.index(ARG_STAR)

mypy/subtypes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,15 +1663,16 @@ def are_parameters_compatible(
16631663

16641664
trivial_vararg_suffix = False
16651665
if (
1666-
right.arg_kinds[-1:] == [ARG_STAR]
1667-
and isinstance(get_proper_type(right.arg_types[-1]), AnyType)
1666+
right_star is not None
1667+
and isinstance(get_proper_type(right_star.typ), AnyType)
16681668
and not is_proper_subtype
1669-
and all(k.is_positional(star=True) for k in left.arg_kinds)
16701669
):
16711670
# Similar to how (*Any, **Any) is considered a supertype of all callables, we consider
16721671
# (*Any) a supertype of all callables with positional arguments. This is needed in
16731672
# particular because we often refuse to try type inference if actual type is not
16741673
# a subtype of erased template type.
1674+
1675+
# This case also ensures that *Any is any length.
16751676
trivial_vararg_suffix = True
16761677

16771678
# Match up corresponding arguments and check them for compatibility. In

test-data/unit/check-typevar-tuple.test

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,18 +2312,18 @@ def good2(*args: str) -> int: ...
23122312
# These are special-cased for *args: Any (as opposite to *args: object)
23132313
def ok1(a: str, b: int, /) -> None: ...
23142314
def ok2(c: bytes, *args: int) -> str: ...
2315+
def ok3(**kwargs: None) -> None: ...
23152316

23162317
def bad1(*, d: str) -> int: ...
2317-
def bad2(**kwargs: None) -> None: ...
23182318

23192319
higher_order(good1)
23202320
higher_order(good2)
23212321

23222322
higher_order(ok1)
23232323
higher_order(ok2)
2324+
higher_order(ok3)
23242325

23252326
higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
2326-
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
23272327
[builtins fixtures/tuple.pyi]
23282328

23292329
[case testAliasToCallableWithUnpack2]
@@ -2381,11 +2381,11 @@ def func(x: Array[Unpack[Ts]], *args: Unpack[Ts]) -> Tuple[Unpack[Ts]]:
23812381
...
23822382

23832383
def a2(x: Array[int, str]) -> None:
2384-
reveal_type(func(x, 2, "Hello")) # N: Revealed type is "Tuple[builtins.int, builtins.str]"
2385-
reveal_type(func(x, 2)) # E: Cannot infer type argument 1 of "func" \
2386-
# N: Revealed type is "builtins.tuple[Any, ...]"
2387-
reveal_type(func(x, 2, "Hello", True)) # E: Cannot infer type argument 1 of "func" \
2388-
# N: Revealed type is "builtins.tuple[Any, ...]"
2384+
reveal_type(func(x, 2, "Hello")) # N: Revealed type is "Tuple[Literal[2]?, Literal['Hello']?]"
2385+
reveal_type(func(x, 2)) # N: Revealed type is "Tuple[Literal[2]?]" \
2386+
# E: Argument 1 to "func" has incompatible type "Array[int, str]"; expected "Array[int]"
2387+
reveal_type(func(x, 2, "Hello", True)) # N: Revealed type is "Tuple[Literal[2]?, Literal['Hello']?, Literal[True]?]" \
2388+
# E: Argument 1 to "func" has incompatible type "Array[int, str]"; expected "Array[int, str, bool]"
23892389
[builtins fixtures/tuple.pyi]
23902390

23912391
[case testTypeVarTupleTypeApplicationOverload]
@@ -2628,3 +2628,37 @@ def fn(f: Callable[[*tuple[T]], int]) -> Callable[[*tuple[T]], int]: ...
26282628
def test(*args: Unpack[tuple[T]]) -> int: ...
26292629
reveal_type(fn(test)) # N: Revealed type is "def [T] (T`1) -> builtins.int"
26302630
[builtins fixtures/tuple.pyi]
2631+
2632+
[case testKwargWithTypeVarTupleInference]
2633+
# https://github.com/python/mypy/issues/16522
2634+
from typing import Generic, TypeVar, Protocol
2635+
from typing_extensions import TypeVarTuple, Unpack
2636+
2637+
PosArgT = TypeVarTuple("PosArgT")
2638+
StatusT = TypeVar("StatusT")
2639+
StatusT_co = TypeVar("StatusT_co", covariant=True)
2640+
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
2641+
2642+
class TaskStatus(Generic[StatusT_contra]):
2643+
def started(self, value: StatusT_contra) -> None: ...
2644+
2645+
class NurseryStartFunc(Protocol[Unpack[PosArgT], StatusT_co]):
2646+
def __call__(
2647+
self,
2648+
*args: Unpack[PosArgT],
2649+
task_status: TaskStatus[StatusT_co],
2650+
) -> object: ...
2651+
2652+
def nursery_start(
2653+
async_fn: NurseryStartFunc[Unpack[PosArgT], StatusT],
2654+
*args: Unpack[PosArgT],
2655+
) -> StatusT: ...
2656+
2657+
def task(a: int, b: str, *, task_status: TaskStatus[str]) -> None: ...
2658+
2659+
def test() -> None:
2660+
reveal_type(nursery_start(task, "a", 2)) # N: Revealed type is "builtins.str" \
2661+
# E: Argument 1 to "nursery_start" has incompatible type "Callable[[int, str, NamedArg(TaskStatus[str], 'task_status')], None]"; expected "NurseryStartFunc[str, int, str]" \
2662+
# N: "NurseryStartFunc[str, int, str].__call__" has type "Callable[[str, int, NamedArg(TaskStatus[str], 'task_status')], object]"
2663+
reveal_type(nursery_start(task, 1, "b")) # N: Revealed type is "builtins.str"
2664+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)