33from  __future__ import  annotations 
44
55from  collections .abc  import  Sequence 
6- from  typing  import  TYPE_CHECKING , Callable 
7- from  typing_extensions  import  TypeGuard 
6+ from  typing  import  TYPE_CHECKING , Callable ,  cast 
7+ from  typing_extensions  import  NewType ,  TypeGuard 
88
99from  mypy  import  nodes 
1010from  mypy .maptype  import  map_instance_to_supertype 
1111from  mypy .typeops  import  make_simplified_union 
1212from  mypy .types  import  (
1313    AnyType ,
14+     CallableType ,
1415    Instance ,
1516    ParamSpecType ,
1617    ProperType ,
1718    TupleType ,
1819    Type ,
1920    TypedDictType ,
2021    TypeOfAny ,
22+     TypeVarId ,
2123    TypeVarTupleType ,
24+     TypeVarType ,
2225    UnionType ,
2326    UnpackType ,
27+     flatten_nested_tuples ,
2428    get_proper_type ,
2529)
2630
2731if  TYPE_CHECKING :
2832    from  mypy .infer  import  ArgumentInferContext 
2933
3034
35+ IterableType  =  NewType ("IterableType" , Instance )
36+ """Represents an instance of `Iterable[T]`.""" 
37+ 
38+ 
3139def  map_actuals_to_formals (
3240    actual_kinds : list [nodes .ArgKind ],
3341    actual_names : Sequence [str  |  None ] |  None ,
@@ -216,92 +224,41 @@ def expand_actual_type(
216224        original_actual  =  actual_type 
217225        actual_type  =  get_proper_type (actual_type )
218226        if  actual_kind  ==  nodes .ARG_STAR :
219-             if  isinstance (actual_type , UnionType ):
220-                 proper_types  =  [get_proper_type (t ) for  t  in  actual_type .items ]
221-                 # special case: union of equal sized tuples.  (e.g. `tuple[int, int] | tuple[None, None]`) 
222-                 if  is_equal_sized_tuples (proper_types ):
223-                     # transform union of tuples into a tuple of unions 
224-                     # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None] 
225-                     tuple_args : list [Type ] =  [
226-                         make_simplified_union (items )
227-                         for  items  in  zip (* (t .items  for  t  in  proper_types ))
228-                     ]
229-                     actual_type  =  TupleType (
230-                         tuple_args ,
231-                         # use Iterable[A | B | C] as the fallback type 
232-                         fallback = Instance (
233-                             self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234-                         ),
235-                     )
236-                 else :
237-                     # reinterpret all union items as iterable types (if possible) 
238-                     # and return the union of the iterable item types results. 
239-                     from  mypy .subtypes  import  is_subtype 
240- 
241-                     iterable_type  =  self .context .iterable_type 
242- 
243-                     def  as_iterable_type (t : Type ) ->  Type :
244-                         """Map a type to the iterable supertype if it is a subtype.""" 
245-                         p_t  =  get_proper_type (t )
246-                         if  isinstance (p_t , Instance ) and  is_subtype (t , iterable_type ):
247-                             return  map_instance_to_supertype (p_t , iterable_type .type )
248-                         if  isinstance (p_t , TupleType ):
249-                             # Convert tuple[A, B, C] to Iterable[A | B | C]. 
250-                             return  Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251-                         return  t 
252- 
253-                     # create copies of self for each item in the union 
254-                     sub_expanders  =  [
255-                         ArgTypeExpander (context = self .context ) for  _  in  actual_type .items 
256-                     ]
257-                     for  expander  in  sub_expanders :
258-                         expander .tuple_index  =  int (self .tuple_index )
259-                         expander .kwargs_used  =  set (self .kwargs_used )
260- 
261-                     candidate_type  =  make_simplified_union (
262-                         [
263-                             e .expand_actual_type (
264-                                 as_iterable_type (item ),
265-                                 actual_kind ,
266-                                 formal_name ,
267-                                 formal_kind ,
268-                                 allow_unpack ,
269-                             )
270-                             for  e , item  in  zip (sub_expanders , actual_type .items )
271-                         ]
272-                     )
273-                     assert  all (expander  ==  sub_expanders [0 ] for  expander  in  sub_expanders )
274-                     # carry over the new state if all sub-expanders are the same state 
275-                     self .tuple_index  =  int (sub_expanders [0 ].tuple_index )
276-                     self .kwargs_used  =  set (sub_expanders [0 ].kwargs_used )
277-                     return  candidate_type 
278- 
279-             if  isinstance (actual_type , TypeVarTupleType ):
280-                 # This code path is hit when *Ts is passed to a callable and various 
281-                 # special-handling didn't catch this. The best thing we can do is to use 
282-                 # the upper bound. 
283-                 actual_type  =  get_proper_type (actual_type .upper_bound )
284-             if  isinstance (actual_type , Instance ) and  actual_type .args :
285-                 from  mypy .subtypes  import  is_subtype 
286- 
287-                 if  is_subtype (actual_type , self .context .iterable_type ):
288-                     return  map_instance_to_supertype (
289-                         actual_type , self .context .iterable_type .type 
290-                     ).args [0 ]
291-                 else :
292-                     # We cannot properly unpack anything other 
293-                     # than `Iterable` type with `*`. 
294-                     # Just return `Any`, other parts of code would raise 
295-                     # a different error for improper use. 
296-                     return  AnyType (TypeOfAny .from_error )
297-             elif  isinstance (actual_type , TupleType ):
227+             # parse *args as one of the following: 
228+             #    IterableType | TupleType | ParamSpecType | AnyType 
229+             star_args  =  self .parse_star_args_type (actual_type )
230+             # star_args = actual_type 
231+ 
232+             # print(f"expand_actual_type:  {actual_type=}  {star_args=}") 
233+ 
234+             # if isinstance(star_args, TypeVarTupleType): 
235+             #     # This code path is hit when *Ts is passed to a callable and various 
236+             #     # special-handling didn't catch this. The best thing we can do is to use 
237+             #     # the upper bound. 
238+             #     star_args = get_proper_type(star_args.upper_bound) 
239+             # if isinstance(star_args, Instance) and star_args.args: 
240+             #     from mypy.subtypes import is_subtype 
241+             # 
242+             #     if is_subtype(star_args, self.context.iterable_type): 
243+             #         return map_instance_to_supertype( 
244+             #             star_args, self.context.iterable_type.type 
245+             #         ).args[0] 
246+             #     else: 
247+             #         # We cannot properly unpack anything other 
248+             #         # than `Iterable` type with `*`. 
249+             #         # Just return `Any`, other parts of code would raise 
250+             #         # a different error for improper use. 
251+             #         return AnyType(TypeOfAny.from_error) 
252+             if  self .is_iterable_type (star_args ):
253+                 return  star_args .args [0 ]
254+             elif  isinstance (star_args , TupleType ):
298255                # Get the next tuple item of a tuple *arg. 
299-                 if  self .tuple_index  >=  len (actual_type .items ):
256+                 if  self .tuple_index  >=  len (star_args .items ):
300257                    # Exhausted a tuple -- continue to the next *args. 
301258                    self .tuple_index  =  1 
302259                else :
303260                    self .tuple_index  +=  1 
304-                 item  =  actual_type .items [self .tuple_index  -  1 ]
261+                 item  =  star_args .items [self .tuple_index  -  1 ]
305262                if  isinstance (item , UnpackType ) and  not  allow_unpack :
306263                    # An unpack item that doesn't have special handling, use upper bound as above. 
307264                    unpacked  =  get_proper_type (item .type )
@@ -315,9 +272,9 @@ def as_iterable_type(t: Type) -> Type:
315272                    )
316273                    item  =  fallback .args [0 ]
317274                return  item 
318-             elif  isinstance (actual_type , ParamSpecType ):
275+             elif  isinstance (star_args , ParamSpecType ):
319276                # ParamSpec is valid in *args but it can't be unpacked. 
320-                 return  actual_type 
277+                 return  star_args 
321278            else :
322279                return  AnyType (TypeOfAny .from_error )
323280        elif  actual_kind  ==  nodes .ARG_STAR2 :
@@ -349,19 +306,197 @@ def as_iterable_type(t: Type) -> Type:
349306            # No translation for other kinds -- 1:1 mapping. 
350307            return  original_actual 
351308
309+     def  is_iterable (self , typ : Type ) ->  bool :
310+         from  mypy .subtypes  import  is_subtype 
311+ 
312+         return  is_subtype (typ , self .context .iterable_type )
313+ 
314+     def  is_iterable_instance_subtype (self , typ : Type ) ->  TypeGuard [Instance ]:
315+         from  mypy .subtypes  import  is_subtype 
316+ 
317+         p_t  =  get_proper_type (typ )
318+         return  (
319+             isinstance (p_t , Instance )
320+             and  bool (p_t .args )
321+             and  is_subtype (p_t , self .context .iterable_type )
322+         )
323+ 
324+     def  is_iterable_type (self , typ : Type ) ->  TypeGuard [IterableType ]:
325+         """Check if the type is an Iterable[T] or a subtype of it.""" 
326+         p_t  =  get_proper_type (typ )
327+         return  isinstance (p_t , Instance ) and  p_t .type  ==  self .context .iterable_type .type 
328+ 
329+     def  as_iterable_type (self , typ : Type ) ->  IterableType  |  AnyType :
330+         """Reinterpret a type as Iterable[T], or return AnyType if not possible.""" 
331+         p_t  =  get_proper_type (typ )
332+         if  self .is_iterable_type (p_t ):
333+             return  p_t 
334+         elif  self .is_iterable_instance_subtype (p_t ):
335+             cls  =  self .context .iterable_type .type 
336+             return  cast (IterableType , map_instance_to_supertype (p_t , cls ))
337+         elif  isinstance (p_t , UnionType ):
338+             # If the type is a union, map each item to the iterable supertype. 
339+             # the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B] 
340+             converted_types  =  [self .as_iterable_type (get_proper_type (item )) for  item  in  p_t .items ]
341+             # if an item could not be interpreted as Iterable[T], we return AnyType 
342+             if  all (self .is_iterable_type (it ) for  it  in  converted_types ):
343+                 # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ] 
344+                 iterable_types  =  cast (list [IterableType ], converted_types )
345+                 arg  =  make_simplified_union ([it .args [0 ] for  it  in  iterable_types ])
346+                 return  self .make_iterable_type (arg )
347+             return  AnyType (TypeOfAny .from_error )
348+         elif  isinstance (p_t , TupleType ):
349+             # maps tuple[A, B, C] -> Iterable[A | B | C] 
350+             # note: proper_elements may contain UnpackType, for instance with 
351+             #   tuple[None, *tuple[None, ...]].. 
352+             proper_elements  =  [get_proper_type (t ) for  t  in  flatten_nested_tuples (p_t .items )]
353+             args : list [Type ] =  []
354+             for  p_e  in  proper_elements :
355+                 if  isinstance (p_e , UnpackType ):
356+                     r  =  self .as_iterable_type (p_e )
357+                     if  self .is_iterable_type (r ):
358+                         args .append (r .args [0 ])
359+                     else :
360+                         args .append (r )
361+                 else :
362+                     args .append (p_e )
363+             return  self .make_iterable_type (make_simplified_union (args ))
364+         if  isinstance (p_t , UnpackType ):
365+             return  self .as_iterable_type (p_t .type )
366+         if  isinstance (p_t , (TypeVarType , TypeVarTupleType )):
367+             return  self .as_iterable_type (p_t .upper_bound )
368+         # fallback: use the solver to reinterpret the type as Iterable[T] 
369+         if  self .is_iterable (p_t ):
370+             return  self ._solve_as_iterable (p_t )
371+         return  AnyType (TypeOfAny .from_error )
372+ 
373+     def  make_iterable_type (self , arg : Type ) ->  IterableType :
374+         value  =  Instance (self .context .iterable_type .type , [arg ])
375+         return  cast (IterableType , value )
376+ 
377+     def  parse_star_args_type (
378+         self , typ : Type 
379+     ) ->  TupleType  |  IterableType  |  ParamSpecType  |  AnyType :
380+         """Parse the type of a *args argument. 
381+ 
382+         Returns one TupleType, IterableType, ParamSpecType or AnyType. 
383+         """ 
384+         p_t  =  get_proper_type (typ )
385+         if  isinstance (p_t , (TupleType , ParamSpecType , AnyType )):
386+             # just return the type as-is 
387+             return  p_t 
388+         elif  isinstance (p_t , TypeVarTupleType ):
389+             return  self .parse_star_args_type (p_t .upper_bound )
390+         elif  isinstance (p_t , UnionType ):
391+             proper_items  =  [get_proper_type (t ) for  t  in  p_t .items ]
392+             # consider 2 cases: 
393+             # 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None] 
394+             #    In this case transform union of same-sized tuples into a tuple of unions 
395+             #    e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None] 
396+             if  is_equal_sized_tuples (proper_items ):
397+ 
398+                 tuple_args : list [Type ] =  [
399+                     make_simplified_union (items ) for  items  in  zip (* (t .items  for  t  in  proper_items ))
400+                 ]
401+                 actual_type  =  TupleType (
402+                     tuple_args ,
403+                     # use Iterable[A | B | C] as the fallback type 
404+                     fallback = Instance (
405+                         self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
406+                     ),
407+                 )
408+                 return  actual_type 
409+             # 2. Union of iterable types, e.g. Iterable[A] | Iterable[B] 
410+             #    In this case return Iterable[A | B] 
411+             #    Note that this covers unions of differently sized tuples as well. 
412+             else :
413+                 converted_types  =  [self .as_iterable_type (p_i ) for  p_i  in  proper_items ]
414+                 if  all (self .is_iterable_type (it ) for  it  in  converted_types ):
415+                     # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ] 
416+                     iterables  =  cast (list [IterableType ], converted_types )
417+                     arg  =  make_simplified_union ([it .args [0 ] for  it  in  iterables ])
418+                     return  self .make_iterable_type (arg )
419+                 else :
420+                     # some items in the union are not iterable, return AnyType 
421+                     return  AnyType (TypeOfAny .from_error )
422+         elif  self .is_iterable_type (parsed  :=  self .as_iterable_type (p_t )):
423+             # in all other cases, we try to reinterpret the type as Iterable[T] 
424+             return  parsed 
425+         return  AnyType (TypeOfAny .from_error )
426+ 
427+     def  _solve_as_iterable (self , typ : Type ) ->  IterableType  |  AnyType :
428+         r"""Use the solver to cast a type as Iterable[T]. 
429+ 
430+         Returns the type as-is if solving fails. 
431+         """ 
432+         from  mypy .constraints  import  infer_constraints_for_callable 
433+         from  mypy .nodes  import  ARG_POS 
434+         from  mypy .solve  import  solve_constraints 
435+ 
436+         iterable_kind  =  self .context .iterable_type .type 
437+ 
438+         # We first create an upcast function: 
439+         #    def [T] (Iterable[T]) -> Iterable[T]: ... 
440+         # and then solve for T, given the input type as the argument. 
441+         T  =  TypeVarType (
442+             "T" ,
443+             "T" ,
444+             TypeVarId (- 1 ),
445+             values = [],
446+             upper_bound = AnyType (TypeOfAny .special_form ),
447+             default = AnyType (TypeOfAny .special_form ),
448+         )
449+         target  =  Instance (iterable_kind , [T ])
450+ 
451+         upcast_callable  =  CallableType (
452+             variables = [T ],
453+             arg_types = [target ],
454+             arg_kinds = [ARG_POS ],
455+             arg_names = [None ],
456+             ret_type = T ,
457+             fallback = self .context .function_type ,
458+         )
459+         constraints  =  infer_constraints_for_callable (
460+             upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context 
461+         )
462+ 
463+         (sol ,), _  =  solve_constraints ([T ], constraints )
464+ 
465+         if  sol  is  None :  # solving failed, return AnyType fallback 
466+             return  AnyType (TypeOfAny .from_error )
467+         return  self .make_iterable_type (sol )
468+ 
352469
353470def  is_equal_sized_tuples (types : Sequence [ProperType ]) ->  TypeGuard [Sequence [TupleType ]]:
354-     """Check if all types are tuples of the same size.""" 
471+     """Check if all types are tuples of the same size. 
472+ 
473+     We use `flatten_nested_tuples` to deal with nested tuples. 
474+     Note that the result may still contain 
475+     """ 
355476    if  not  types :
356477        return  True 
357478
358479    iterator  =  iter (types )
359-     first  =  next (iterator )
360-     if  not  isinstance (first , TupleType ):
480+     typ  =  next (iterator )
481+     if  not  isinstance (typ , TupleType ):
482+         return  False 
483+     flattened_elements  =  flatten_nested_tuples (typ .items )
484+     if  any (
485+         isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
486+         for  member  in  flattened_elements 
487+     ):
488+         # this can happen e.g. with tuple[int, *tuple[int, ...], int] 
361489        return  False 
362-     size  =  first . length ( )
490+     size  =  len ( flattened_elements )
363491
364-     for  item  in  iterator :
365-         if  not  isinstance (item , TupleType ) or  item .length () !=  size :
492+     for  typ  in  iterator :
493+         if  not  isinstance (typ , TupleType ):
494+             return  False 
495+         flattened_elements  =  flatten_nested_tuples (typ .items )
496+         if  len (flattened_elements ) !=  size  or  any (
497+             isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
498+             for  member  in  flattened_elements 
499+         ):
500+             # this can happen e.g. with tuple[int, *tuple[int, ...], int] 
366501            return  False 
367502    return  True 
0 commit comments