Skip to content

Commit e169d3e

Browse files
address review
1 parent 0a6d757 commit e169d3e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

mypy/argmap.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from collections.abc import Sequence
66
from typing import TYPE_CHECKING, Callable, cast
7-
from typing_extensions import NewType, TypeGuard
7+
from typing_extensions import NewType, TypeGuard, TypeIs
88

99
from mypy import nodes
1010
from mypy.maptype import map_instance_to_supertype
@@ -278,11 +278,12 @@ def expand_actual_type(
278278
return original_actual
279279

280280
def is_iterable(self, typ: Type) -> bool:
281+
"""Check if the type is an iterable, i.e. implements the Iterable Protocol."""
281282
from mypy.subtypes import is_subtype
282283

283284
return is_subtype(typ, self.context.iterable_type)
284285

285-
def is_iterable_instance_type(self, typ: Type) -> TypeGuard[IterableType]:
286+
def is_iterable_instance_type(self, typ: Type) -> TypeIs[IterableType]:
286287
"""Check if the type is an Iterable[T]."""
287288
p_t = get_proper_type(typ)
288289
return isinstance(p_t, Instance) and p_t.type == self.context.iterable_type.type
@@ -300,8 +301,6 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
300301
from mypy.nodes import ARG_POS
301302
from mypy.solve import solve_constraints
302303

303-
iterable_kind = self.context.iterable_type.type
304-
305304
# We first create an upcast function:
306305
# def [T] (Iterable[T]) -> Iterable[T]: ...
307306
# and then solve for T, given the input type as the argument.
@@ -310,21 +309,20 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
310309
"T",
311310
TypeVarId(-1),
312311
values=[],
313-
upper_bound=AnyType(TypeOfAny.special_form),
314-
default=AnyType(TypeOfAny.special_form),
312+
upper_bound=AnyType(TypeOfAny.from_omitted_generics),
313+
default=AnyType(TypeOfAny.from_omitted_generics),
315314
)
316-
target = Instance(iterable_kind, [T])
317-
315+
target = self._make_iterable_instance_type(T)
318316
upcast_callable = CallableType(
319317
variables=[T],
320318
arg_types=[target],
321319
arg_kinds=[ARG_POS],
322320
arg_names=[None],
323-
ret_type=T,
321+
ret_type=target,
324322
fallback=self.context.function_type,
325323
)
326324
constraints = infer_constraints_for_callable(
327-
upcast_callable, [typ], [ARG_POS], [None], [[0]], context=self.context
325+
upcast_callable, [typ], [ARG_POS], [None], [[0]], self.context
328326
)
329327

330328
(sol,), _ = solve_constraints([T], constraints)
@@ -334,7 +332,11 @@ def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
334332
return self._make_iterable_instance_type(sol)
335333

336334
def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
337-
"""Reinterpret a type as Iterable[T], or return AnyType if not possible."""
335+
"""Reinterpret a type as Iterable[T], or return AnyType if not possible.
336+
337+
This function specially handles certain types like UnionType, TupleType, and UnpackType.
338+
Otherwise, the upcasting is performed using the solver.
339+
"""
338340
p_t = get_proper_type(typ)
339341
if self.is_iterable_instance_type(p_t) or isinstance(p_t, AnyType):
340342
return p_t
@@ -386,8 +388,8 @@ def parse_star_args_type(
386388
) -> TupleType | IterableType | ParamSpecType | AnyType:
387389
"""Parse the type of a ``*args`` argument.
388390
389-
Returns one of TupleType, IterableType, ParamSpecType,
390-
or AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid.
391+
Returns one of TupleType, IterableType, ParamSpecType or AnyType.
392+
Returns AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid.
391393
"""
392394
p_t = get_proper_type(typ)
393395
if isinstance(p_t, (TupleType, ParamSpecType, AnyType)):

0 commit comments

Comments
 (0)