Skip to content

Commit da03a38

Browse files
committed
Modify approach to collectively infer union instead of forcing.
1 parent 6cfd442 commit da03a38

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

mypy/checkexpr.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4990,7 +4990,13 @@ def apply_type_arguments_to_callable(
49904990

49914991
def visit_list_expr(self, e: ListExpr) -> Type:
49924992
"""Type check a list expression [...]."""
4993-
return self.check_lst_expr(e, "builtins.list", "<list>")
4993+
if len(e.items) > 0:
4994+
item_types = []
4995+
for item in e.items:
4996+
item_type = self.accept(item)
4997+
item_types.append(item_type)
4998+
element_type = make_simplified_union(item_types)
4999+
return self.chk.named_generic_type("builtins.list", [element_type])
49945000

49955001
def visit_set_expr(self, e: SetExpr) -> Type:
49965002
return self.check_lst_expr(e, "builtins.set", "<set>")
@@ -5017,11 +5023,10 @@ def fast_container_type(
50175023
values: list[Type] = []
50185024
for item in e.items:
50195025
if isinstance(item, StarExpr):
5020-
# fallback to slow path
50215026
self.resolved_type[e] = NoneType()
50225027
return None
50235028
values.append(self.accept(item))
5024-
vt = join.join_type_list(values)
5029+
vt = make_simplified_union(values)
50255030
if not allow_fast_container_literal(vt):
50265031
self.resolved_type[e] = NoneType()
50275032
return None
@@ -5064,9 +5069,6 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50645069
#!BUG this forces all list exprs to be forced (we don't want that)
50655070

50665071
# Translate into type checking a generic function call.
5067-
# Used for list and set expressions, as well as for tuples
5068-
# containing star expressions that don't refer to a
5069-
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
50705072
tv = TypeVarType(
50715073
"T",
50725074
"T",
@@ -5075,6 +5077,31 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50755077
upper_bound=self.object_type(),
50765078
default=AnyType(TypeOfAny.from_omitted_generics),
50775079
)
5080+
5081+
# get all elements and join types from list items
5082+
if isinstance(e, ListExpr):
5083+
item_types: list[Type] = []
5084+
for item in e.items:
5085+
if isinstance(item, StarExpr):
5086+
starred_type = self.accept(item.expr)
5087+
starred_type = get_proper_type(starred_type)
5088+
if isinstance(starred_type, TupleType):
5089+
item_types.extend(starred_type.items)
5090+
else:
5091+
item_types.append(starred_type)
5092+
else:
5093+
item_types.append(self.accept(item))
5094+
unified_type = join.join_type_list(item_types)
5095+
if not isinstance(unified_type, (AnyType, UninhabitedType)):
5096+
tv = TypeVarType(
5097+
"T",
5098+
"T",
5099+
id=TypeVarId(-1, namespace="<lst>"),
5100+
values=[],
5101+
upper_bound=unified_type,
5102+
default=unified_type,
5103+
)
5104+
50785105
constructor = CallableType(
50795106
[tv],
50805107
[nodes.ARG_STAR],
@@ -5084,6 +5111,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50845111
name=tag,
50855112
variables=[tv],
50865113
)
5114+
50875115
out = self.check_call(
50885116
constructor,
50895117
[(i.expr if isinstance(i, StarExpr) else i) for i in e.items],
@@ -5776,21 +5804,17 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None:
57765804
_, sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
57775805
else:
57785806
_, sequence_type = self.chk.analyze_iterable_item_type(sequence)
5807+
sequence_type = get_proper_type(sequence_type)
5808+
if isinstance(sequence_type, Instance):
5809+
if sequence_type.type.fullname == "builtins.list":
5810+
item_types = []
5811+
if isinstance(sequence, ListExpr):
5812+
for item in sequence.items:
5813+
item_type = self.accept(item)
5814+
item_types.append(item_type)
5815+
if item_types:
5816+
sequence_type = make_simplified_union(item_types)
57795817
self.chk.analyze_index_variables(index, sequence_type, True, e)
5780-
for condition in conditions:
5781-
self.accept(condition)
5782-
5783-
# values are only part of the comprehension when all conditions are true
5784-
true_map, false_map = self.chk.find_isinstance_check(condition)
5785-
5786-
if true_map:
5787-
self.chk.push_type_map(true_map)
5788-
5789-
if codes.REDUNDANT_EXPR in self.chk.options.enabled_error_codes:
5790-
if true_map is None:
5791-
self.msg.redundant_condition_in_comprehension(False, condition)
5792-
elif false_map is None:
5793-
self.msg.redundant_condition_in_comprehension(True, condition)
57945818

57955819
def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type:
57965820
self.accept(e.cond)

0 commit comments

Comments
 (0)