44
55from collections .abc import Sequence
66from typing import TYPE_CHECKING , Callable
7+ from typing_extensions import TypeGuard
78
89from mypy import nodes
910from mypy .maptype import map_instance_to_supertype
11+ from mypy .typeops import make_simplified_union
1012from mypy .types import (
1113 AnyType ,
1214 Instance ,
1315 ParamSpecType ,
16+ ProperType ,
1417 TupleType ,
1518 Type ,
1619 TypedDictType ,
1720 TypeOfAny ,
1821 TypeVarTupleType ,
22+ UnionType ,
1923 UnpackType ,
2024 get_proper_type ,
2125)
@@ -54,6 +58,16 @@ def map_actuals_to_formals(
5458 elif actual_kind == nodes .ARG_STAR :
5559 # We need to know the actual type to map varargs.
5660 actualt = get_proper_type (actual_arg_type (ai ))
61+
62+ # Special case for union of equal sized tuples.
63+ if (
64+ isinstance (actualt , UnionType )
65+ and actualt .items
66+ and is_equal_sized_tuples (
67+ proper_types := [get_proper_type (t ) for t in actualt .items ]
68+ )
69+ ):
70+ actualt = proper_types [0 ]
5771 if isinstance (actualt , TupleType ):
5872 # A tuple actual maps to a fixed number of formals.
5973 for _ in range (len (actualt .items )):
@@ -171,6 +185,15 @@ def __init__(self, context: ArgumentInferContext) -> None:
171185 # Type context for `*` and `**` arg kinds.
172186 self .context = context
173187
188+ def __eq__ (self , other : object ) -> bool :
189+ if isinstance (other , ArgTypeExpander ):
190+ return (
191+ self .tuple_index == other .tuple_index
192+ and self .kwargs_used == other .kwargs_used
193+ and self .context == other .context
194+ )
195+ return NotImplemented
196+
174197 def expand_actual_type (
175198 self ,
176199 actual_type : Type ,
@@ -193,6 +216,66 @@ def expand_actual_type(
193216 original_actual = actual_type
194217 actual_type = get_proper_type (actual_type )
195218 if actual_kind == nodes .ARG_STAR :
219+ if isinstance (actual_type , UnionType ):
220+ proper_types = [get_proper_type (t ) for t in actual_type .items ]
221+ # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222+ if is_equal_sized_tuples (proper_types ):
223+ # transform union of tuples into a tuple of unions
224+ # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225+ tuple_args : list [Type ] = [
226+ make_simplified_union (items )
227+ for items in zip (* (t .items for t in proper_types ))
228+ ]
229+ actual_type = TupleType (
230+ tuple_args ,
231+ # use Iterable[A | B | C] as the fallback type
232+ fallback = Instance (
233+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234+ ),
235+ )
236+ else :
237+ # reinterpret all union items as iterable types (if possible)
238+ # and return the union of the iterable item types results.
239+ from mypy .subtypes import is_subtype
240+
241+ iterable_type = self .context .iterable_type
242+
243+ def as_iterable_type (t : Type ) -> Type :
244+ """Map a type to the iterable supertype if it is a subtype."""
245+ p_t = get_proper_type (t )
246+ if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247+ return map_instance_to_supertype (p_t , iterable_type .type )
248+ if isinstance (p_t , TupleType ):
249+ # Convert tuple[A, B, C] to Iterable[A | B | C].
250+ return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251+ return t
252+
253+ # create copies of self for each item in the union
254+ sub_expanders = [
255+ ArgTypeExpander (context = self .context ) for _ in actual_type .items
256+ ]
257+ for expander in sub_expanders :
258+ expander .tuple_index = int (self .tuple_index )
259+ expander .kwargs_used = set (self .kwargs_used )
260+
261+ candidate_type = make_simplified_union (
262+ [
263+ e .expand_actual_type (
264+ as_iterable_type (item ),
265+ actual_kind ,
266+ formal_name ,
267+ formal_kind ,
268+ allow_unpack ,
269+ )
270+ for e , item in zip (sub_expanders , actual_type .items )
271+ ]
272+ )
273+ assert all (expander == sub_expanders [0 ] for expander in sub_expanders )
274+ # carry over the new state if all sub-expanders are the same state
275+ self .tuple_index = int (sub_expanders [0 ].tuple_index )
276+ self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
277+ return candidate_type
278+
196279 if isinstance (actual_type , TypeVarTupleType ):
197280 # This code path is hit when *Ts is passed to a callable and various
198281 # special-handling didn't catch this. The best thing we can do is to use
@@ -265,3 +348,20 @@ def expand_actual_type(
265348 else :
266349 # No translation for other kinds -- 1:1 mapping.
267350 return original_actual
351+
352+
353+ def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
354+ """Check if all types are tuples of the same size."""
355+ if not types :
356+ return True
357+
358+ iterator = iter (types )
359+ first = next (iterator )
360+ if not isinstance (first , TupleType ):
361+ return False
362+ size = first .length ()
363+
364+ for item in iterator :
365+ if not isinstance (item , TupleType ) or item .length () != size :
366+ return False
367+ return True
0 commit comments