Skip to content

Commit f5cc357

Browse files
committed
Modify approach to collectively infer union instead of forcing.
1 parent 0f28a5f commit f5cc357

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

mypy/checkexpr.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4977,7 +4977,13 @@ def apply_type_arguments_to_callable(
49774977

49784978
def visit_list_expr(self, e: ListExpr) -> Type:
49794979
"""Type check a list expression [...]."""
4980-
return self.check_lst_expr(e, "builtins.list", "<list>")
4980+
if len(e.items) > 0:
4981+
item_types = []
4982+
for item in e.items:
4983+
item_type = self.accept(item)
4984+
item_types.append(item_type)
4985+
element_type = make_simplified_union(item_types)
4986+
return self.chk.named_generic_type("builtins.list", [element_type])
49814987

49824988
def visit_set_expr(self, e: SetExpr) -> Type:
49834989
return self.check_lst_expr(e, "builtins.set", "<set>")
@@ -5004,11 +5010,10 @@ def fast_container_type(
50045010
values: list[Type] = []
50055011
for item in e.items:
50065012
if isinstance(item, StarExpr):
5007-
# fallback to slow path
50085013
self.resolved_type[e] = NoneType()
50095014
return None
50105015
values.append(self.accept(item))
5011-
vt = join.join_type_list(values)
5016+
vt = make_simplified_union(values)
50125017
if not allow_fast_container_literal(vt):
50135018
self.resolved_type[e] = NoneType()
50145019
return None
@@ -5051,9 +5056,6 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50515056
#!BUG this forces all list exprs to be forced (we don't want that)
50525057

50535058
# Translate into type checking a generic function call.
5054-
# Used for list and set expressions, as well as for tuples
5055-
# containing star expressions that don't refer to a
5056-
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
50575059
tv = TypeVarType(
50585060
"T",
50595061
"T",
@@ -5062,6 +5064,31 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50625064
upper_bound=self.object_type(),
50635065
default=AnyType(TypeOfAny.from_omitted_generics),
50645066
)
5067+
5068+
# get all elements and join types from list items
5069+
if isinstance(e, ListExpr):
5070+
item_types: list[Type] = []
5071+
for item in e.items:
5072+
if isinstance(item, StarExpr):
5073+
starred_type = self.accept(item.expr)
5074+
starred_type = get_proper_type(starred_type)
5075+
if isinstance(starred_type, TupleType):
5076+
item_types.extend(starred_type.items)
5077+
else:
5078+
item_types.append(starred_type)
5079+
else:
5080+
item_types.append(self.accept(item))
5081+
unified_type = join.join_type_list(item_types)
5082+
if not isinstance(unified_type, (AnyType, UninhabitedType)):
5083+
tv = TypeVarType(
5084+
"T",
5085+
"T",
5086+
id=TypeVarId(-1, namespace="<lst>"),
5087+
values=[],
5088+
upper_bound=unified_type,
5089+
default=unified_type,
5090+
)
5091+
50655092
constructor = CallableType(
50665093
[tv],
50675094
[nodes.ARG_STAR],
@@ -5071,6 +5098,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
50715098
name=tag,
50725099
variables=[tv],
50735100
)
5101+
50745102
out = self.check_call(
50755103
constructor,
50765104
[(i.expr if isinstance(i, StarExpr) else i) for i in e.items],
@@ -5763,21 +5791,17 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None:
57635791
_, sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
57645792
else:
57655793
_, sequence_type = self.chk.analyze_iterable_item_type(sequence)
5794+
sequence_type = get_proper_type(sequence_type)
5795+
if isinstance(sequence_type, Instance):
5796+
if sequence_type.type.fullname == "builtins.list":
5797+
item_types = []
5798+
if isinstance(sequence, ListExpr):
5799+
for item in sequence.items:
5800+
item_type = self.accept(item)
5801+
item_types.append(item_type)
5802+
if item_types:
5803+
sequence_type = make_simplified_union(item_types)
57665804
self.chk.analyze_index_variables(index, sequence_type, True, e)
5767-
for condition in conditions:
5768-
self.accept(condition)
5769-
5770-
# values are only part of the comprehension when all conditions are true
5771-
true_map, false_map = self.chk.find_isinstance_check(condition)
5772-
5773-
if true_map:
5774-
self.chk.push_type_map(true_map)
5775-
5776-
if codes.REDUNDANT_EXPR in self.chk.options.enabled_error_codes:
5777-
if true_map is None:
5778-
self.msg.redundant_condition_in_comprehension(False, condition)
5779-
elif false_map is None:
5780-
self.msg.redundant_condition_in_comprehension(True, condition)
57815805

57825806
def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type:
57835807
self.accept(e.cond)

0 commit comments

Comments
 (0)