Skip to content

Commit 59035f0

Browse files
committed
Better handling of Any/object in variadic generics
1 parent 29ffa3e commit 59035f0

File tree

4 files changed

+130
-23
lines changed

4 files changed

+130
-23
lines changed

mypy/erasetype.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ def visit_tuple_type(self, t: TupleType) -> Type:
203203
return unpacked
204204
return result
205205

206+
def visit_callable_type(self, t: CallableType) -> Type:
207+
result = super().visit_callable_type(t)
208+
assert isinstance(result, ProperType) and isinstance(result, CallableType)
209+
# Usually this is done in semanal_typeargs.py, but erasure can create
210+
# a non-normal callable from normal one.
211+
result.normalize_trivial_unpack()
212+
return result
213+
206214
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
207215
if self.erase_id(t.id):
208216
return t.tuple_fallback.copy_modified(args=[self.replacement])

mypy/expandtype.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def visit_instance(self, t: Instance) -> Type:
226226
if isinstance(arg, UnpackType):
227227
unpacked = get_proper_type(arg.type)
228228
if isinstance(unpacked, Instance):
229+
# TODO: this and similar asserts below may be unsafe because get_proper_type()
230+
# may be called during semantic analysis before all invalid types are removed.
229231
assert unpacked.type.fullname == "builtins.tuple"
230232
args = list(unpacked.args)
231233
return t.copy_modified(args=args)
@@ -333,10 +335,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
333335

334336
var_arg_type = get_proper_type(var_arg.type)
335337
new_unpack: Type
336-
if isinstance(var_arg_type, Instance):
337-
# we have something like Unpack[Tuple[Any, ...]]
338-
new_unpack = UnpackType(var_arg.type.accept(self))
339-
elif isinstance(var_arg_type, TupleType):
338+
if isinstance(var_arg_type, TupleType):
340339
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
341340
expanded_tuple = var_arg_type.accept(self)
342341
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
@@ -348,6 +347,11 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
348347
fallback = var_arg_type.tuple_fallback
349348
expanded_items = self.expand_unpack(var_arg)
350349
new_unpack = UnpackType(TupleType(expanded_items, fallback))
350+
# Since get_proper_type() may be called in semanal.py before callable
351+
# normalization happens, we need to also handle non-normal cases here.
352+
elif isinstance(var_arg_type, Instance):
353+
# we have something like Unpack[Tuple[Any, ...]]
354+
new_unpack = UnpackType(var_arg.type.accept(self))
351355
else:
352356
# We have invalid type in Unpack. This can happen when expanding aliases
353357
# to Callable[[*Invalid], Ret]

mypy/subtypes.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterator
3+
from collections.abc import Iterable, Iterator
44
from contextlib import contextmanager
55
from typing import Any, Callable, Final, TypeVar, cast
66
from typing_extensions import TypeAlias as _TypeAlias
@@ -414,6 +414,9 @@ def _is_subtype(self, left: Type, right: Type) -> bool:
414414
return is_proper_subtype(left, right, subtype_context=self.subtype_context)
415415
return is_subtype(left, right, subtype_context=self.subtype_context)
416416

417+
def _all_subtypes(self, lefts: Iterable[Type], rights: Iterable[Type]) -> bool:
418+
return all(self._is_subtype(li, ri) for (li, ri) in zip(lefts, rights))
419+
417420
# visit_x(left) means: is left (which is an instance of X) a subtype of right?
418421

419422
def visit_unbound_type(self, left: UnboundType) -> bool:
@@ -856,11 +859,25 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
856859
# There are some items on the left that will never have a matching length
857860
# on the right.
858861
return False
862+
left_prefix = left_unpack_index
863+
left_suffix = len(left.items) - left_prefix - 1
859864
left_unpack = left.items[left_unpack_index]
860865
assert isinstance(left_unpack, UnpackType)
861866
left_unpacked = get_proper_type(left_unpack.type)
862867
if not isinstance(left_unpacked, Instance):
863-
# *Ts unpacks can't be split.
868+
# *Ts unpack can't be split, except if it is all mapped to Anys or objects.
869+
if self.is_top_type(right_item):
870+
right_prefix_types, middle, right_suffix_types = split_with_prefix_and_suffix(
871+
tuple(right.items), left_prefix, left_suffix
872+
)
873+
if not all(
874+
self.is_top_type(ri) or isinstance(ri, UnpackType) for ri in middle
875+
):
876+
return False
877+
# Also check the tails match as well.
878+
return self._all_subtypes(
879+
left.items[:left_prefix], right_prefix_types
880+
) and self._all_subtypes(left.items[-left_suffix:], right_suffix_types)
864881
return False
865882
assert left_unpacked.type.fullname == "builtins.tuple"
866883
left_item = left_unpacked.args[0]
@@ -871,8 +888,6 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
871888
# and then check subtyping for all finite overlaps.
872889
if not self._is_subtype(left_item, right_item):
873890
return False
874-
left_prefix = left_unpack_index
875-
left_suffix = len(left.items) - left_prefix - 1
876891
max_overlap = max(0, right_prefix - left_prefix, right_suffix - left_suffix)
877892
for overlap in range(max_overlap + 1):
878893
repr_items = left.items[:left_prefix] + [left_item] * overlap
@@ -883,6 +898,11 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
883898
return False
884899
return True
885900

901+
def is_top_type(self, typ: Type) -> bool:
902+
if not self.proper_subtype and isinstance(get_proper_type(typ), AnyType):
903+
return True
904+
return is_named_instance(typ, "builtins.object")
905+
886906
def visit_typeddict_type(self, left: TypedDictType) -> bool:
887907
right = self.right
888908
if isinstance(right, Instance):
@@ -1653,17 +1673,18 @@ def are_parameters_compatible(
16531673
return True
16541674
trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype
16551675

1676+
trivial_vararg_suffix = False
16561677
if (
1657-
right.arg_kinds == [ARG_STAR]
1658-
and isinstance(get_proper_type(right.arg_types[0]), AnyType)
1678+
right.arg_kinds[-1:] == [ARG_STAR]
1679+
and isinstance(get_proper_type(right.arg_types[-1]), AnyType)
16591680
and not is_proper_subtype
1681+
and all(k.is_positional(star=True) for k in left.arg_kinds)
16601682
):
16611683
# Similar to how (*Any, **Any) is considered a supertype of all callables, we consider
16621684
# (*Any) a supertype of all callables with positional arguments. This is needed in
16631685
# particular because we often refuse to try type inference if actual type is not
16641686
# a subtype of erased template type.
1665-
if all(k.is_positional() for k in left.arg_kinds) and ignore_pos_arg_names:
1666-
return True
1687+
trivial_vararg_suffix = True
16671688

16681689
# Match up corresponding arguments and check them for compatibility. In
16691690
# every pair (argL, argR) of corresponding arguments from L and R, argL must
@@ -1697,7 +1718,11 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
16971718
return not allow_partial_overlap and not trivial_suffix
16981719
return not is_compat(right_arg.typ, left_arg.typ)
16991720

1700-
if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
1721+
if (
1722+
_incompatible(left_star, right_star)
1723+
and not trivial_vararg_suffix
1724+
or _incompatible(left_star2, right_star2)
1725+
):
17011726
return False
17021727

17031728
# Phase 1b: Check non-star args: for every arg right can accept, left must
@@ -1727,8 +1752,8 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
17271752
# Phase 1c: Check var args. Right has an infinite series of optional positional
17281753
# arguments. Get all further positional args of left, and make sure
17291754
# they're more general than the corresponding member in right.
1730-
# TODO: are we handling UnpackType correctly here?
1731-
if right_star is not None:
1755+
# TODO: handle suffix in UnpackType (i.e. *args: *Tuple[Ts, X, Y]).
1756+
if right_star is not None and not trivial_vararg_suffix:
17321757
# Synthesize an anonymous formal argument for the right
17331758
right_by_position = right.try_synthesizing_arg_from_vararg(None)
17341759
assert right_by_position is not None

test-data/unit/check-typevar-tuple.test

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,18 +2305,21 @@ def higher_order(f: _CallableValue) -> None: ...
23052305
def good1(*args: int) -> None: ...
23062306
def good2(*args: str) -> int: ...
23072307

2308-
def bad1(a: str, b: int, /) -> None: ...
2309-
def bad2(c: bytes, *args: int) -> str: ...
2310-
def bad3(*, d: str) -> int: ...
2311-
def bad4(**kwargs: None) -> None: ...
2308+
# These are special-cased for *args: Any (as opposite to *args: object)
2309+
def ok1(a: str, b: int, /) -> None: ...
2310+
def ok2(c: bytes, *args: int) -> str: ...
2311+
2312+
def bad1(*, d: str) -> int: ...
2313+
def bad2(**kwargs: None) -> None: ...
23122314

23132315
higher_order(good1)
23142316
higher_order(good2)
23152317

2316-
higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[str, int], None]"; expected "Callable[[VarArg(Any)], Any]"
2317-
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[bytes, VarArg(int)], str]"; expected "Callable[[VarArg(Any)], Any]"
2318-
higher_order(bad3) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
2319-
higher_order(bad4) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
2318+
higher_order(ok1)
2319+
higher_order(ok2)
2320+
2321+
higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
2322+
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
23202323
[builtins fixtures/tuple.pyi]
23212324

23222325
[case testAliasToCallableWithUnpack2]
@@ -2513,3 +2516,70 @@ x4: Foo[Unpack[tuple[str, ...]]]
25132516
y4: Foo[Unpack[tuple[int, int]]]
25142517
x4 is y4 # E: Non-overlapping identity check (left operand type: "Foo[Unpack[Tuple[str, ...]]]", right operand type: "Foo[int, int]")
25152518
[builtins fixtures/tuple.pyi]
2519+
2520+
[case testTypeVarTupleErasureNormalized]
2521+
from typing import TypeVarTuple, Unpack, Generic, Union
2522+
from collections.abc import Callable
2523+
2524+
Args = TypeVarTuple("Args")
2525+
2526+
class Built(Generic[Unpack[Args]]):
2527+
pass
2528+
2529+
def example(
2530+
fn: Union[Built[Unpack[Args]], Callable[[Unpack[Args]], None]]
2531+
) -> Built[Unpack[Args]]: ...
2532+
2533+
@example
2534+
def command() -> None:
2535+
return
2536+
reveal_type(command) # N: Revealed type is "__main__.Built[()]"
2537+
[builtins fixtures/tuple.pyi]
2538+
2539+
[case testTypeVarTupleSelfMappedPrefix]
2540+
from typing import TypeVarTuple, Generic
2541+
2542+
Ts = TypeVarTuple("Ts")
2543+
class Base(Generic[*Ts]):
2544+
attr: tuple[*Ts]
2545+
2546+
@property
2547+
def prop(self) -> tuple[*Ts]:
2548+
return self.attr
2549+
2550+
def meth(self) -> tuple[*Ts]:
2551+
return self.attr
2552+
2553+
Ss = TypeVarTuple("Ss")
2554+
class Derived(Base[str, *Ss]):
2555+
def test(self) -> None:
2556+
reveal_type(self.attr) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2557+
reveal_type(self.prop) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2558+
reveal_type(self.meth()) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2559+
[builtins fixtures/property.pyi]
2560+
2561+
[case testTypeVarTupleProtocolPrefix]
2562+
from typing import Protocol, Unpack, TypeVarTuple
2563+
2564+
Ts = TypeVarTuple("Ts")
2565+
class A(Protocol[Unpack[Ts]]):
2566+
def f(self, z: str, *args: Unpack[Ts]) -> None: ...
2567+
2568+
class C:
2569+
def f(self, z: str, x: int) -> None: ...
2570+
2571+
def f(x: A[Unpack[Ts]]) -> tuple[Unpack[Ts]]: ...
2572+
2573+
reveal_type(f(C())) # N: Revealed type is "Tuple[builtins.int]"
2574+
[builtins fixtures/tuple.pyi]
2575+
2576+
[case testTypeVarTupleHomogeneousCallableNormalized]
2577+
from typing import Generic, Unpack, TypeVarTuple
2578+
2579+
Ts = TypeVarTuple("Ts")
2580+
class C(Generic[Unpack[Ts]]):
2581+
def foo(self, *args: Unpack[Ts]) -> None: ...
2582+
2583+
c: C[Unpack[tuple[int, ...]]]
2584+
reveal_type(c.foo) # N: Revealed type is "def (*args: builtins.int)"
2585+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)