44
55from collections .abc import Sequence
66from typing import TYPE_CHECKING , Callable
7+ from typing_extensions import TypeGuard
78
89from mypy import nodes
9- from mypy .join import join_type_list
1010from mypy .maptype import map_instance_to_supertype
1111from mypy .typeops import make_simplified_union
1212from mypy .types import (
1313 AnyType ,
1414 Instance ,
1515 ParamSpecType ,
16+ ProperType ,
1617 TupleType ,
1718 Type ,
1819 TypedDictType ,
@@ -62,10 +63,11 @@ def map_actuals_to_formals(
6263 if (
6364 isinstance (actualt , UnionType )
6465 and actualt .items
65- and is_equal_sized_tuples (actualt .items )
66+ and is_equal_sized_tuples (
67+ proper_types := [get_proper_type (t ) for t in actualt .items ]
68+ )
6669 ):
67- # Arbitrarily pick the first item in the union.
68- actualt = get_proper_type (actualt .items [0 ])
70+ actualt = proper_types [0 ]
6971 if isinstance (actualt , TupleType ):
7072 # A tuple actual maps to a fixed number of formals.
7173 for _ in range (len (actualt .items )):
@@ -215,18 +217,38 @@ def expand_actual_type(
215217 actual_type = get_proper_type (actual_type )
216218 if actual_kind == nodes .ARG_STAR :
217219 if isinstance (actual_type , UnionType ):
218- # special case 1: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
219- # special case 2: union contains no static sized tuples. (e.g. `list[str | None] | list[str]`)
220- if is_equal_sized_tuples (actual_type .items ) or not any (
221- isinstance (get_proper_type (t ), TupleType ) for t in actual_type .items
222- ):
223- # If the actual type is a union, try expanding it.
224- # Example: f(*args), where args is `list[str | None] | list[str]`,
225- # Example: f(*args), where args is `tuple[A, B, C] | tuple[None, None, None]`
226- # Note: there is potential for combinatorial explosion here:
227- # f(*x1, *x2, .. *xn), if xₖ is a union of nₖ differently sized tuples,
228- # then there are n₁ * n₂ * ... * nₖ possible combinations of pointer positions.
229- # therefore, we only take this branch if all union members consume the same number of items.
220+ proper_types = [get_proper_type (t ) for t in actual_type .items ]
221+ # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222+ if is_equal_sized_tuples (proper_types ):
223+ # transform union of tuples into a tuple of unions
224+ # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225+ tuple_args : list [Type ] = [
226+ make_simplified_union (items )
227+ for items in zip (* (t .items for t in proper_types ))
228+ ]
229+ actual_type = TupleType (
230+ tuple_args ,
231+ # use Iterable[A | B | C] as the fallback type
232+ fallback = Instance (
233+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234+ ),
235+ )
236+ else :
237+ # reinterpret all union items as iterable types (if possible)
238+ # and return the union of the iterable item types results.
239+ from mypy .subtypes import is_subtype
240+
241+ iterable_type = self .context .iterable_type
242+
243+ def as_iterable_type (t : Type ) -> Type :
244+ """Map a type to the iterable supertype if it is a subtype."""
245+ p_t = get_proper_type (t )
246+ if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247+ return map_instance_to_supertype (p_t , iterable_type .type )
248+ if isinstance (p_t , TupleType ):
249+ # Convert tuple[A, B, C] to Iterable[A | B | C].
250+ return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251+ return t
230252
231253 # create copies of self for each item in the union
232254 sub_expanders = [
@@ -239,7 +261,11 @@ def expand_actual_type(
239261 candidate_type = make_simplified_union (
240262 [
241263 e .expand_actual_type (
242- item , actual_kind , formal_name , formal_kind , allow_unpack
264+ as_iterable_type (item ),
265+ actual_kind ,
266+ formal_name ,
267+ formal_kind ,
268+ allow_unpack ,
243269 )
244270 for e , item in zip (sub_expanders , actual_type .items )
245271 ]
@@ -249,28 +275,6 @@ def expand_actual_type(
249275 self .tuple_index = int (sub_expanders [0 ].tuple_index )
250276 self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
251277 return candidate_type
252- else :
253- # otherwise, we fall back to checking using the join of the union members.
254- # for better results we first map all instances to Iterable[T]
255- from mypy .subtypes import is_subtype
256-
257- iterable_type = self .context .iterable_type
258-
259- def as_iterable_type (t : Type ) -> Type :
260- """Map a type to the iterable supertype if it is a subtype."""
261- p_t = get_proper_type (t )
262- if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
263- return map_instance_to_supertype (p_t , iterable_type .type )
264- if isinstance (p_t , TupleType ):
265- # Convert tuple[A, B, C] to Iterable[A | B | C].
266- return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
267- return t
268-
269- joined_type = join_type_list ([as_iterable_type (t ) for t in actual_type .items ])
270- assert not isinstance (get_proper_type (joined_type ), TupleType )
271- return self .expand_actual_type (
272- joined_type , actual_kind , formal_name , formal_kind , allow_unpack
273- )
274278
275279 if isinstance (actual_type , TypeVarTupleType ):
276280 # This code path is hit when *Ts is passed to a callable and various
@@ -346,19 +350,18 @@ def as_iterable_type(t: Type) -> Type:
346350 return original_actual
347351
348352
349- def is_equal_sized_tuples (types : Sequence [Type ]) -> bool :
353+ def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [ Sequence [ TupleType ]] :
350354 """Check if all types are tuples of the same size."""
351355 if not types :
352356 return True
353357
354358 iterator = iter (types )
355- first = get_proper_type ( next (iterator ) )
359+ first = next (iterator )
356360 if not isinstance (first , TupleType ):
357361 return False
358362 size = first .length ()
359363
360364 for item in iterator :
361- p_t = get_proper_type (item )
362- if not isinstance (p_t , TupleType ) or p_t .length () != size :
365+ if not isinstance (item , TupleType ) or item .length () != size :
363366 return False
364367 return True
0 commit comments