44
55from collections .abc import Sequence
66from typing import TYPE_CHECKING , Callable , cast
7- from typing_extensions import NewType , TypeGuard
7+ from typing_extensions import NewType , TypeGuard , TypeIs
88
99from mypy import nodes
1010from mypy .maptype import map_instance_to_supertype
@@ -278,11 +278,12 @@ def expand_actual_type(
278278 return original_actual
279279
280280 def is_iterable (self , typ : Type ) -> bool :
281+ """Check if the type is an iterable, i.e. implements the Iterable Protocol."""
281282 from mypy .subtypes import is_subtype
282283
283284 return is_subtype (typ , self .context .iterable_type )
284285
285- def is_iterable_instance_type (self , typ : Type ) -> TypeGuard [IterableType ]:
286+ def is_iterable_instance_type (self , typ : Type ) -> TypeIs [IterableType ]:
286287 """Check if the type is an Iterable[T]."""
287288 p_t = get_proper_type (typ )
288289 return isinstance (p_t , Instance ) and p_t .type == self .context .iterable_type .type
@@ -300,8 +301,6 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
300301 from mypy .nodes import ARG_POS
301302 from mypy .solve import solve_constraints
302303
303- iterable_kind = self .context .iterable_type .type
304-
305304 # We first create an upcast function:
306305 # def [T] (Iterable[T]) -> Iterable[T]: ...
307306 # and then solve for T, given the input type as the argument.
@@ -310,21 +309,20 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
310309 "T" ,
311310 TypeVarId (- 1 ),
312311 values = [],
313- upper_bound = AnyType (TypeOfAny .special_form ),
314- default = AnyType (TypeOfAny .special_form ),
312+ upper_bound = AnyType (TypeOfAny .from_omitted_generics ),
313+ default = AnyType (TypeOfAny .from_omitted_generics ),
315314 )
316- target = Instance (iterable_kind , [T ])
317-
315+ target = self ._make_iterable_instance_type (T )
318316 upcast_callable = CallableType (
319317 variables = [T ],
320318 arg_types = [target ],
321319 arg_kinds = [ARG_POS ],
322320 arg_names = [None ],
323- ret_type = T ,
321+ ret_type = target ,
324322 fallback = self .context .function_type ,
325323 )
326324 constraints = infer_constraints_for_callable (
327- upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
325+ upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], self .context
328326 )
329327
330328 (sol ,), _ = solve_constraints ([T ], constraints )
@@ -334,7 +332,11 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
334332 return self ._make_iterable_instance_type (sol )
335333
336334 def as_iterable_type (self , typ : Type ) -> IterableType | AnyType :
337- """Reinterpret a type as Iterable[T], or return AnyType if not possible."""
335+ """Reinterpret a type as Iterable[T], or return AnyType if not possible.
336+
337+ This function specially handles certain types like UnionType, TupleType, and UnpackType.
338+ Otherwise, the upcasting is performed using the solver.
339+ """
338340 p_t = get_proper_type (typ )
339341 if self .is_iterable_instance_type (p_t ) or isinstance (p_t , AnyType ):
340342 return p_t
@@ -386,8 +388,8 @@ def parse_star_args_type(
386388 ) -> TupleType | IterableType | ParamSpecType | AnyType :
387389 """Parse the type of a ``*args`` argument.
388390
389- Returns one of TupleType, IterableType, ParamSpecType,
390- or AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid.
391+ Returns one of TupleType, IterableType, ParamSpecType or AnyType.
392+ Returns AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid.
391393 """
392394 p_t = get_proper_type (typ )
393395 if isinstance (p_t , (TupleType , ParamSpecType , AnyType )):
0 commit comments