@@ -75,6 +75,7 @@ def map_actuals_to_formals(
7575 proper_types := [get_proper_type (t ) for t in actualt .items ]
7676 )
7777 ):
78+ # pick an arbitrary member
7879 actualt = proper_types [0 ]
7980 if isinstance (actualt , TupleType ):
8081 # A tuple actual maps to a fixed number of formals.
@@ -193,15 +194,6 @@ def __init__(self, context: ArgumentInferContext) -> None:
193194 # Type context for `*` and `**` arg kinds.
194195 self .context = context
195196
196- def __eq__ (self , other : object ) -> bool :
197- if isinstance (other , ArgTypeExpander ):
198- return (
199- self .tuple_index == other .tuple_index
200- and self .kwargs_used == other .kwargs_used
201- and self .context == other .context
202- )
203- return NotImplemented
204-
205197 def expand_actual_type (
206198 self ,
207199 actual_type : Type ,
@@ -227,29 +219,8 @@ def expand_actual_type(
227219 # parse *args as one of the following:
228220 # IterableType | TupleType | ParamSpecType | AnyType
229221 star_args = self .parse_star_args_type (actual_type )
230- # star_args = actual_type
231-
232- # print(f"expand_actual_type: {actual_type=} {star_args=}")
233-
234- # if isinstance(star_args, TypeVarTupleType):
235- # # This code path is hit when *Ts is passed to a callable and various
236- # # special-handling didn't catch this. The best thing we can do is to use
237- # # the upper bound.
238- # star_args = get_proper_type(star_args.upper_bound)
239- # if isinstance(star_args, Instance) and star_args.args:
240- # from mypy.subtypes import is_subtype
241- #
242- # if is_subtype(star_args, self.context.iterable_type):
243- # return map_instance_to_supertype(
244- # star_args, self.context.iterable_type.type
245- # ).args[0]
246- # else:
247- # # We cannot properly unpack anything other
248- # # than `Iterable` type with `*`.
249- # # Just return `Any`, other parts of code would raise
250- # # a different error for improper use.
251- # return AnyType(TypeOfAny.from_error)
252- if self .is_iterable_type (star_args ):
222+
223+ if self .is_iterable_instance_type (star_args ):
253224 return star_args .args [0 ]
254225 elif isinstance (star_args , TupleType ):
255226 # Get the next tuple item of a tuple *arg.
@@ -321,30 +292,75 @@ def is_iterable_instance_subtype(self, typ: Type) -> TypeGuard[Instance]:
321292 and is_subtype (p_t , self .context .iterable_type )
322293 )
323294
324- def is_iterable_type (self , typ : Type ) -> TypeGuard [IterableType ]:
295+ def is_iterable_instance_type (self , typ : Type ) -> TypeGuard [IterableType ]:
325296 """Check if the type is an Iterable[T] or a subtype of it."""
326297 p_t = get_proper_type (typ )
327298 return isinstance (p_t , Instance ) and p_t .type == self .context .iterable_type .type
328299
300+ def _make_iterable_instance_type (self , arg : Type ) -> IterableType :
301+ value = Instance (self .context .iterable_type .type , [arg ])
302+ return cast (IterableType , value )
303+
304+ def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
305+ r"""Use the solver to cast a type as Iterable[T].
306+
307+ Returns `AnyType` if solving fails.
308+ """
309+ from mypy .constraints import infer_constraints_for_callable
310+ from mypy .nodes import ARG_POS
311+ from mypy .solve import solve_constraints
312+
313+ iterable_kind = self .context .iterable_type .type
314+
315+ # We first create an upcast function:
316+ # def [T] (Iterable[T]) -> Iterable[T]: ...
317+ # and then solve for T, given the input type as the argument.
318+ T = TypeVarType (
319+ "T" ,
320+ "T" ,
321+ TypeVarId (- 1 ),
322+ values = [],
323+ upper_bound = AnyType (TypeOfAny .special_form ),
324+ default = AnyType (TypeOfAny .special_form ),
325+ )
326+ target = Instance (iterable_kind , [T ])
327+
328+ upcast_callable = CallableType (
329+ variables = [T ],
330+ arg_types = [target ],
331+ arg_kinds = [ARG_POS ],
332+ arg_names = [None ],
333+ ret_type = T ,
334+ fallback = self .context .function_type ,
335+ )
336+ constraints = infer_constraints_for_callable (
337+ upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
338+ )
339+
340+ (sol ,), _ = solve_constraints ([T ], constraints )
341+
342+ if sol is None : # solving failed, return AnyType fallback
343+ return AnyType (TypeOfAny .from_error )
344+ return self ._make_iterable_instance_type (sol )
345+
329346 def as_iterable_type (self , typ : Type ) -> IterableType | AnyType :
330347 """Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331348 p_t = get_proper_type (typ )
332- if self .is_iterable_type (p_t ):
349+ if self .is_iterable_instance_type (p_t ) or isinstance ( p_t , AnyType ):
333350 return p_t
334- elif self .is_iterable_instance_subtype (p_t ):
335- cls = self .context .iterable_type .type
336- return cast (IterableType , map_instance_to_supertype (p_t , cls ))
337351 elif isinstance (p_t , UnionType ):
338352 # If the type is a union, map each item to the iterable supertype.
339353 # the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340354 converted_types = [self .as_iterable_type (get_proper_type (item )) for item in p_t .items ]
341- # if an item could not be interpreted as Iterable[T], we return AnyType
342- if all (self .is_iterable_type (it ) for it in converted_types ):
355+
356+ if any (not self .is_iterable_instance_type (it ) for it in converted_types ):
357+ # if any item could not be interpreted as Iterable[T], we return AnyType
358+ return AnyType (TypeOfAny .from_error )
359+ else :
343360 # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344361 iterable_types = cast (list [IterableType ], converted_types )
345362 arg = make_simplified_union ([it .args [0 ] for it in iterable_types ])
346- return self .make_iterable_type (arg )
347- return AnyType (TypeOfAny .from_error )
363+ return self ._make_iterable_instance_type (arg )
348364 elif isinstance (p_t , TupleType ):
349365 # maps tuple[A, B, C] -> Iterable[A | B | C]
350366 # note: proper_elements may contain UnpackType, for instance with
@@ -354,26 +370,24 @@ def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
354370 for p_e in proper_elements :
355371 if isinstance (p_e , UnpackType ):
356372 r = self .as_iterable_type (p_e )
357- if self .is_iterable_type (r ):
373+ if self .is_iterable_instance_type (r ):
358374 args .append (r .args [0 ])
359375 else :
376+ # this *should* never happen
360377 args .append (r )
361378 else :
362379 args .append (p_e )
363- return self .make_iterable_type (make_simplified_union (args ))
364- if isinstance (p_t , UnpackType ):
380+ return self ._make_iterable_instance_type (make_simplified_union (args ))
381+ elif isinstance (p_t , UnpackType ):
365382 return self .as_iterable_type (p_t .type )
366- if isinstance (p_t , (TypeVarType , TypeVarTupleType )):
383+ elif isinstance (p_t , (TypeVarType , TypeVarTupleType )):
367384 return self .as_iterable_type (p_t .upper_bound )
368- # fallback: use the solver to reinterpret the type as Iterable[T]
369- if self .is_iterable (p_t ):
385+ elif self .is_iterable (p_t ):
386+ # TODO: add a 'fast path' (needs measurement) that uses the map_instance_to_supertype
387+ # mechanism? (Only if it works: gh-19662)
370388 return self ._solve_as_iterable (p_t )
371389 return AnyType (TypeOfAny .from_error )
372390
373- def make_iterable_type (self , arg : Type ) -> IterableType :
374- value = Instance (self .context .iterable_type .type , [arg ])
375- return cast (IterableType , value )
376-
377391 def parse_star_args_type (
378392 self , typ : Type
379393 ) -> TupleType | IterableType | ParamSpecType | AnyType :
@@ -411,61 +425,19 @@ def parse_star_args_type(
411425 # Note that this covers unions of differently sized tuples as well.
412426 else :
413427 converted_types = [self .as_iterable_type (p_i ) for p_i in proper_items ]
414- if all (self .is_iterable_type (it ) for it in converted_types ):
428+ if all (self .is_iterable_instance_type (it ) for it in converted_types ):
415429 # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416430 iterables = cast (list [IterableType ], converted_types )
417431 arg = make_simplified_union ([it .args [0 ] for it in iterables ])
418- return self .make_iterable_type (arg )
432+ return self ._make_iterable_instance_type (arg )
419433 else :
420434 # some items in the union are not iterable, return AnyType
421435 return AnyType (TypeOfAny .from_error )
422- elif self .is_iterable_type (parsed := self .as_iterable_type (p_t )):
436+ elif self .is_iterable_instance_type (parsed := self .as_iterable_type (p_t )):
423437 # in all other cases, we try to reinterpret the type as Iterable[T]
424438 return parsed
425439 return AnyType (TypeOfAny .from_error )
426440
427- def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
428- r"""Use the solver to cast a type as Iterable[T].
429-
430- Returns the type as-is if solving fails.
431- """
432- from mypy .constraints import infer_constraints_for_callable
433- from mypy .nodes import ARG_POS
434- from mypy .solve import solve_constraints
435-
436- iterable_kind = self .context .iterable_type .type
437-
438- # We first create an upcast function:
439- # def [T] (Iterable[T]) -> Iterable[T]: ...
440- # and then solve for T, given the input type as the argument.
441- T = TypeVarType (
442- "T" ,
443- "T" ,
444- TypeVarId (- 1 ),
445- values = [],
446- upper_bound = AnyType (TypeOfAny .special_form ),
447- default = AnyType (TypeOfAny .special_form ),
448- )
449- target = Instance (iterable_kind , [T ])
450-
451- upcast_callable = CallableType (
452- variables = [T ],
453- arg_types = [target ],
454- arg_kinds = [ARG_POS ],
455- arg_names = [None ],
456- ret_type = T ,
457- fallback = self .context .function_type ,
458- )
459- constraints = infer_constraints_for_callable (
460- upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
461- )
462-
463- (sol ,), _ = solve_constraints ([T ], constraints )
464-
465- if sol is None : # solving failed, return AnyType fallback
466- return AnyType (TypeOfAny .from_error )
467- return self .make_iterable_type (sol )
468-
469441
470442def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
471443 """Check if all types are tuples of the same size.
0 commit comments