33from __future__ import annotations
44
55from collections .abc import Sequence
6- from typing import TYPE_CHECKING , Callable
7- from typing_extensions import TypeGuard
6+ from typing import TYPE_CHECKING , Callable , cast
7+ from typing_extensions import NewType , TypeGuard
88
99from mypy import nodes
1010from mypy .maptype import map_instance_to_supertype
1111from mypy .typeops import make_simplified_union
1212from mypy .types import (
1313 AnyType ,
14+ CallableType ,
1415 Instance ,
1516 ParamSpecType ,
1617 ProperType ,
1718 TupleType ,
1819 Type ,
1920 TypedDictType ,
2021 TypeOfAny ,
22+ TypeVarId ,
2123 TypeVarTupleType ,
24+ TypeVarType ,
2225 UnionType ,
2326 UnpackType ,
27+ flatten_nested_tuples ,
2428 get_proper_type ,
2529)
2630
2731if TYPE_CHECKING :
2832 from mypy .infer import ArgumentInferContext
2933
3034
35+ IterableType = NewType ("IterableType" , Instance )
36+ """Represents an instance of `Iterable[T]`."""
37+
38+
3139def map_actuals_to_formals (
3240 actual_kinds : list [nodes .ArgKind ],
3341 actual_names : Sequence [str | None ] | None ,
@@ -216,92 +224,41 @@ def expand_actual_type(
216224 original_actual = actual_type
217225 actual_type = get_proper_type (actual_type )
218226 if actual_kind == nodes .ARG_STAR :
219- if isinstance (actual_type , UnionType ):
220- proper_types = [get_proper_type (t ) for t in actual_type .items ]
221- # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222- if is_equal_sized_tuples (proper_types ):
223- # transform union of tuples into a tuple of unions
224- # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225- tuple_args : list [Type ] = [
226- make_simplified_union (items )
227- for items in zip (* (t .items for t in proper_types ))
228- ]
229- actual_type = TupleType (
230- tuple_args ,
231- # use Iterable[A | B | C] as the fallback type
232- fallback = Instance (
233- self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234- ),
235- )
236- else :
237- # reinterpret all union items as iterable types (if possible)
238- # and return the union of the iterable item types results.
239- from mypy .subtypes import is_subtype
240-
241- iterable_type = self .context .iterable_type
242-
243- def as_iterable_type (t : Type ) -> Type :
244- """Map a type to the iterable supertype if it is a subtype."""
245- p_t = get_proper_type (t )
246- if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247- return map_instance_to_supertype (p_t , iterable_type .type )
248- if isinstance (p_t , TupleType ):
249- # Convert tuple[A, B, C] to Iterable[A | B | C].
250- return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251- return t
252-
253- # create copies of self for each item in the union
254- sub_expanders = [
255- ArgTypeExpander (context = self .context ) for _ in actual_type .items
256- ]
257- for expander in sub_expanders :
258- expander .tuple_index = int (self .tuple_index )
259- expander .kwargs_used = set (self .kwargs_used )
260-
261- candidate_type = make_simplified_union (
262- [
263- e .expand_actual_type (
264- as_iterable_type (item ),
265- actual_kind ,
266- formal_name ,
267- formal_kind ,
268- allow_unpack ,
269- )
270- for e , item in zip (sub_expanders , actual_type .items )
271- ]
272- )
273- assert all (expander == sub_expanders [0 ] for expander in sub_expanders )
274- # carry over the new state if all sub-expanders are the same state
275- self .tuple_index = int (sub_expanders [0 ].tuple_index )
276- self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
277- return candidate_type
278-
279- if isinstance (actual_type , TypeVarTupleType ):
280- # This code path is hit when *Ts is passed to a callable and various
281- # special-handling didn't catch this. The best thing we can do is to use
282- # the upper bound.
283- actual_type = get_proper_type (actual_type .upper_bound )
284- if isinstance (actual_type , Instance ) and actual_type .args :
285- from mypy .subtypes import is_subtype
286-
287- if is_subtype (actual_type , self .context .iterable_type ):
288- return map_instance_to_supertype (
289- actual_type , self .context .iterable_type .type
290- ).args [0 ]
291- else :
292- # We cannot properly unpack anything other
293- # than `Iterable` type with `*`.
294- # Just return `Any`, other parts of code would raise
295- # a different error for improper use.
296- return AnyType (TypeOfAny .from_error )
297- elif isinstance (actual_type , TupleType ):
227+ # parse *args as one of the following:
228+ # IterableType | TupleType | ParamSpecType | AnyType
229+ star_args = self .parse_star_args_type (actual_type )
230+ # star_args = actual_type
231+
232+ # print(f"expand_actual_type: {actual_type=} {star_args=}")
233+
234+ # if isinstance(star_args, TypeVarTupleType):
235+ # # This code path is hit when *Ts is passed to a callable and various
236+ # # special-handling didn't catch this. The best thing we can do is to use
237+ # # the upper bound.
238+ # star_args = get_proper_type(star_args.upper_bound)
239+ # if isinstance(star_args, Instance) and star_args.args:
240+ # from mypy.subtypes import is_subtype
241+ #
242+ # if is_subtype(star_args, self.context.iterable_type):
243+ # return map_instance_to_supertype(
244+ # star_args, self.context.iterable_type.type
245+ # ).args[0]
246+ # else:
247+ # # We cannot properly unpack anything other
248+ # # than `Iterable` type with `*`.
249+ # # Just return `Any`, other parts of code would raise
250+ # # a different error for improper use.
251+ # return AnyType(TypeOfAny.from_error)
252+ if self .is_iterable_type (star_args ):
253+ return star_args .args [0 ]
254+ elif isinstance (star_args , TupleType ):
298255 # Get the next tuple item of a tuple *arg.
299- if self .tuple_index >= len (actual_type .items ):
256+ if self .tuple_index >= len (star_args .items ):
300257 # Exhausted a tuple -- continue to the next *args.
301258 self .tuple_index = 1
302259 else :
303260 self .tuple_index += 1
304- item = actual_type .items [self .tuple_index - 1 ]
261+ item = star_args .items [self .tuple_index - 1 ]
305262 if isinstance (item , UnpackType ) and not allow_unpack :
306263 # An unpack item that doesn't have special handling, use upper bound as above.
307264 unpacked = get_proper_type (item .type )
@@ -315,9 +272,9 @@ def as_iterable_type(t: Type) -> Type:
315272 )
316273 item = fallback .args [0 ]
317274 return item
318- elif isinstance (actual_type , ParamSpecType ):
275+ elif isinstance (star_args , ParamSpecType ):
319276 # ParamSpec is valid in *args but it can't be unpacked.
320- return actual_type
277+ return star_args
321278 else :
322279 return AnyType (TypeOfAny .from_error )
323280 elif actual_kind == nodes .ARG_STAR2 :
@@ -349,19 +306,197 @@ def as_iterable_type(t: Type) -> Type:
349306 # No translation for other kinds -- 1:1 mapping.
350307 return original_actual
351308
309+ def is_iterable (self , typ : Type ) -> bool :
310+ from mypy .subtypes import is_subtype
311+
312+ return is_subtype (typ , self .context .iterable_type )
313+
314+ def is_iterable_instance_subtype (self , typ : Type ) -> TypeGuard [Instance ]:
315+ from mypy .subtypes import is_subtype
316+
317+ p_t = get_proper_type (typ )
318+ return (
319+ isinstance (p_t , Instance )
320+ and bool (p_t .args )
321+ and is_subtype (p_t , self .context .iterable_type )
322+ )
323+
324+ def is_iterable_type (self , typ : Type ) -> TypeGuard [IterableType ]:
325+ """Check if the type is an Iterable[T] or a subtype of it."""
326+ p_t = get_proper_type (typ )
327+ return isinstance (p_t , Instance ) and p_t .type == self .context .iterable_type .type
328+
329+ def as_iterable_type (self , typ : Type ) -> IterableType | AnyType :
330+ """Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331+ p_t = get_proper_type (typ )
332+ if self .is_iterable_type (p_t ):
333+ return p_t
334+ elif self .is_iterable_instance_subtype (p_t ):
335+ cls = self .context .iterable_type .type
336+ return cast (IterableType , map_instance_to_supertype (p_t , cls ))
337+ elif isinstance (p_t , UnionType ):
338+ # If the type is a union, map each item to the iterable supertype.
339+ # the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340+ converted_types = [self .as_iterable_type (get_proper_type (item )) for item in p_t .items ]
341+ # if an item could not be interpreted as Iterable[T], we return AnyType
342+ if all (self .is_iterable_type (it ) for it in converted_types ):
343+ # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344+ iterable_types = cast (list [IterableType ], converted_types )
345+ arg = make_simplified_union ([it .args [0 ] for it in iterable_types ])
346+ return self .make_iterable_type (arg )
347+ return AnyType (TypeOfAny .from_error )
348+ elif isinstance (p_t , TupleType ):
349+ # maps tuple[A, B, C] -> Iterable[A | B | C]
350+ # note: proper_elements may contain UnpackType, for instance with
351+ # tuple[None, *tuple[None, ...]]..
352+ proper_elements = [get_proper_type (t ) for t in flatten_nested_tuples (p_t .items )]
353+ args : list [Type ] = []
354+ for p_e in proper_elements :
355+ if isinstance (p_e , UnpackType ):
356+ r = self .as_iterable_type (p_e )
357+ if self .is_iterable_type (r ):
358+ args .append (r .args [0 ])
359+ else :
360+ args .append (r )
361+ else :
362+ args .append (p_e )
363+ return self .make_iterable_type (make_simplified_union (args ))
364+ if isinstance (p_t , UnpackType ):
365+ return self .as_iterable_type (p_t .type )
366+ if isinstance (p_t , (TypeVarType , TypeVarTupleType )):
367+ return self .as_iterable_type (p_t .upper_bound )
368+ # fallback: use the solver to reinterpret the type as Iterable[T]
369+ if self .is_iterable (p_t ):
370+ return self ._solve_as_iterable (p_t )
371+ return AnyType (TypeOfAny .from_error )
372+
373+ def make_iterable_type (self , arg : Type ) -> IterableType :
374+ value = Instance (self .context .iterable_type .type , [arg ])
375+ return cast (IterableType , value )
376+
377+ def parse_star_args_type (
378+ self , typ : Type
379+ ) -> TupleType | IterableType | ParamSpecType | AnyType :
380+ """Parse the type of a *args argument.
381+
382+ Returns one TupleType, IterableType, ParamSpecType or AnyType.
383+ """
384+ p_t = get_proper_type (typ )
385+ if isinstance (p_t , (TupleType , ParamSpecType , AnyType )):
386+ # just return the type as-is
387+ return p_t
388+ elif isinstance (p_t , TypeVarTupleType ):
389+ return self .parse_star_args_type (p_t .upper_bound )
390+ elif isinstance (p_t , UnionType ):
391+ proper_items = [get_proper_type (t ) for t in p_t .items ]
392+ # consider 2 cases:
393+ # 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None]
394+ # In this case transform union of same-sized tuples into a tuple of unions
395+ # e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None]
396+ if is_equal_sized_tuples (proper_items ):
397+
398+ tuple_args : list [Type ] = [
399+ make_simplified_union (items ) for items in zip (* (t .items for t in proper_items ))
400+ ]
401+ actual_type = TupleType (
402+ tuple_args ,
403+ # use Iterable[A | B | C] as the fallback type
404+ fallback = Instance (
405+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
406+ ),
407+ )
408+ return actual_type
409+ # 2. Union of iterable types, e.g. Iterable[A] | Iterable[B]
410+ # In this case return Iterable[A | B]
411+ # Note that this covers unions of differently sized tuples as well.
412+ else :
413+ converted_types = [self .as_iterable_type (p_i ) for p_i in proper_items ]
414+ if all (self .is_iterable_type (it ) for it in converted_types ):
415+ # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416+ iterables = cast (list [IterableType ], converted_types )
417+ arg = make_simplified_union ([it .args [0 ] for it in iterables ])
418+ return self .make_iterable_type (arg )
419+ else :
420+ # some items in the union are not iterable, return AnyType
421+ return AnyType (TypeOfAny .from_error )
422+ elif self .is_iterable_type (parsed := self .as_iterable_type (p_t )):
423+ # in all other cases, we try to reinterpret the type as Iterable[T]
424+ return parsed
425+ return AnyType (TypeOfAny .from_error )
426+
427+ def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
428+ r"""Use the solver to cast a type as Iterable[T].
429+
430+ Returns the type as-is if solving fails.
431+ """
432+ from mypy .constraints import infer_constraints_for_callable
433+ from mypy .nodes import ARG_POS
434+ from mypy .solve import solve_constraints
435+
436+ iterable_kind = self .context .iterable_type .type
437+
438+ # We first create an upcast function:
439+ # def [T] (Iterable[T]) -> Iterable[T]: ...
440+ # and then solve for T, given the input type as the argument.
441+ T = TypeVarType (
442+ "T" ,
443+ "T" ,
444+ TypeVarId (- 1 ),
445+ values = [],
446+ upper_bound = AnyType (TypeOfAny .special_form ),
447+ default = AnyType (TypeOfAny .special_form ),
448+ )
449+ target = Instance (iterable_kind , [T ])
450+
451+ upcast_callable = CallableType (
452+ variables = [T ],
453+ arg_types = [target ],
454+ arg_kinds = [ARG_POS ],
455+ arg_names = [None ],
456+ ret_type = T ,
457+ fallback = self .context .function_type ,
458+ )
459+ constraints = infer_constraints_for_callable (
460+ upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
461+ )
462+
463+ (sol ,), _ = solve_constraints ([T ], constraints )
464+
465+ if sol is None : # solving failed, return AnyType fallback
466+ return AnyType (TypeOfAny .from_error )
467+ return self .make_iterable_type (sol )
468+
352469
353470def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
354- """Check if all types are tuples of the same size."""
471+ """Check if all types are tuples of the same size.
472+
473+ We use `flatten_nested_tuples` to deal with nested tuples.
474+ Note that the result may still contain
475+ """
355476 if not types :
356477 return True
357478
358479 iterator = iter (types )
359- first = next (iterator )
360- if not isinstance (first , TupleType ):
480+ typ = next (iterator )
481+ if not isinstance (typ , TupleType ):
482+ return False
483+ flattened_elements = flatten_nested_tuples (typ .items )
484+ if any (
485+ isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
486+ for member in flattened_elements
487+ ):
488+ # this can happen e.g. with tuple[int, *tuple[int, ...], int]
361489 return False
362- size = first . length ( )
490+ size = len ( flattened_elements )
363491
364- for item in iterator :
365- if not isinstance (item , TupleType ) or item .length () != size :
492+ for typ in iterator :
493+ if not isinstance (typ , TupleType ):
494+ return False
495+ flattened_elements = flatten_nested_tuples (typ .items )
496+ if len (flattened_elements ) != size or any (
497+ isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
498+ for member in flattened_elements
499+ ):
500+ # this can happen e.g. with tuple[int, *tuple[int, ...], int]
366501 return False
367502 return True
0 commit comments