Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4977,7 +4977,13 @@ def apply_type_arguments_to_callable(

def visit_list_expr(self, e: ListExpr) -> Type:
"""Type check a list expression [...]."""
return self.check_lst_expr(e, "builtins.list", "<list>")
if len(e.items) > 0:
item_types = []
for item in e.items:
item_type = self.accept(item)
item_types.append(item_type)
element_type = make_simplified_union(item_types)
return self.chk.named_generic_type("builtins.list", [element_type])

def visit_set_expr(self, e: SetExpr) -> Type:
return self.check_lst_expr(e, "builtins.set", "<set>")
Expand All @@ -5004,28 +5010,52 @@ def fast_container_type(
values: list[Type] = []
for item in e.items:
if isinstance(item, StarExpr):
# fallback to slow path
self.resolved_type[e] = NoneType()
return None
values.append(self.accept(item))
vt = join.join_type_list(values)
vt = make_simplified_union(values)
if not allow_fast_container_literal(vt):
self.resolved_type[e] = NoneType()
return None
ct = self.chk.named_generic_type(container_fullname, [vt])
self.resolved_type[e] = ct
return ct

def infer_item_type(self, item_types: List[Type]) -> Type:
"""Infer the item type for a list based on its elements."""
joined_type = self.chk.join_types(*item_types)
proper_joined = get_proper_type(joined_type)
if (
isinstance(proper_joined, Instance)
and proper_joined.type.fullname == "builtins.object"
and len(set(map(type, item_types))) > 1
):
# if we can't find a common supertype other than 'object',
# use a Union of the item types
return UnionType.make_simplified_union(item_types)
else:
# otherwise just use the common supertype
return joined_type

def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: str) -> Type:
# fast path
t = self.fast_container_type(e, fullname)
if t:
return t

# try to infer directly
if isinstance(e, ListExpr):
item_types = [self.accept(item) for item in e.items]
if not item_types:
item_type = AnyType(TypeOfAny.from_empty_collection)
else:
# use all the element types
item_type = self.infer_item_type(item_types)
return self.chk.named_generic_type(fullname, [item_type])

#!BUG this forces all list exprs to be forced (we don't want that)

# Translate into type checking a generic function call.
# Used for list and set expressions, as well as for tuples
# containing star expressions that don't refer to a
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
tv = TypeVarType(
"T",
"T",
Expand All @@ -5034,6 +5064,31 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)

# get all elements and join types from list items
if isinstance(e, ListExpr):
item_types: list[Type] = []
for item in e.items:
if isinstance(item, StarExpr):
starred_type = self.accept(item.expr)
starred_type = get_proper_type(starred_type)
if isinstance(starred_type, TupleType):
item_types.extend(starred_type.items)
else:
item_types.append(starred_type)
else:
item_types.append(self.accept(item))
unified_type = join.join_type_list(item_types)
if not isinstance(unified_type, (AnyType, UninhabitedType)):
tv = TypeVarType(
"T",
"T",
id=TypeVarId(-1, namespace="<lst>"),
values=[],
upper_bound=unified_type,
default=unified_type,
)

constructor = CallableType(
[tv],
[nodes.ARG_STAR],
Expand All @@ -5043,6 +5098,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
name=tag,
variables=[tv],
)

out = self.check_call(
constructor,
[(i.expr if isinstance(i, StarExpr) else i) for i in e.items],
Expand Down Expand Up @@ -5735,21 +5791,17 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None:
_, sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
else:
_, sequence_type = self.chk.analyze_iterable_item_type(sequence)
sequence_type = get_proper_type(sequence_type)
if isinstance(sequence_type, Instance):
if sequence_type.type.fullname == "builtins.list":
item_types = []
if isinstance(sequence, ListExpr):
for item in sequence.items:
item_type = self.accept(item)
item_types.append(item_type)
if item_types:
sequence_type = make_simplified_union(item_types)
self.chk.analyze_index_variables(index, sequence_type, True, e)
for condition in conditions:
self.accept(condition)

# values are only part of the comprehension when all conditions are true
true_map, false_map = self.chk.find_isinstance_check(condition)

if true_map:
self.chk.push_type_map(true_map)

if codes.REDUNDANT_EXPR in self.chk.options.enabled_error_codes:
if true_map is None:
self.msg.redundant_condition_in_comprehension(False, condition)
elif false_map is None:
self.msg.redundant_condition_in_comprehension(True, condition)

def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type:
self.accept(e.cond)
Expand Down
20 changes: 20 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2452,3 +2452,23 @@ x + T # E: Unsupported left operand type for + ("int")
T() # E: "TypeVar" not callable
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testListComprehensionWithTupleUnionTypeGenerator]
class A: pass
class B: pass
a = A()
b = B()
l3: list[A | B] = [x for x in (a, b)]
l3: list[A | B] = [x for x in [a, b]]

[case testListComprehensionWithListUnionTypeGenerator]
a = A()
b = "foo"
l3: list[A | str] = [x for x in (a, b)]
l3: list[A | str] = [x for x in [a, b]]

[case testListComprehensionWithImmutableTypeUnionTypeGenerator]
a = 3.0
b = 3
l3: list[float | int] = [x for x in (a, b)]
l3: list[float | int] = [x for x in [a, b]]
Loading