diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 11a9cffe18c3..b06f951b698a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4977,7 +4977,13 @@ def apply_type_arguments_to_callable( def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e, "builtins.list", "") + if len(e.items) > 0: + item_types = [] + for item in e.items: + item_type = self.accept(item) + item_types.append(item_type) + element_type = make_simplified_union(item_types) + return self.chk.named_generic_type("builtins.list", [element_type]) def visit_set_expr(self, e: SetExpr) -> Type: return self.check_lst_expr(e, "builtins.set", "") @@ -5004,11 +5010,10 @@ def fast_container_type( values: list[Type] = [] for item in e.items: if isinstance(item, StarExpr): - # fallback to slow path self.resolved_type[e] = NoneType() return None values.append(self.accept(item)) - vt = join.join_type_list(values) + vt = make_simplified_union(values) if not allow_fast_container_literal(vt): self.resolved_type[e] = NoneType() return None @@ -5016,16 +5021,41 @@ def fast_container_type( self.resolved_type[e] = ct return ct + def infer_item_type(self, item_types: List[Type]) -> Type: + """Infer the item type for a list based on its elements.""" + joined_type = self.chk.join_types(*item_types) + proper_joined = get_proper_type(joined_type) + if ( + isinstance(proper_joined, Instance) + and proper_joined.type.fullname == "builtins.object" + and len(set(map(type, item_types))) > 1 + ): + # if we can't find a common supertype other than 'object', + # use a Union of the item types + return UnionType.make_simplified_union(item_types) + else: + # otherwise just use the common supertype + return joined_type + def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: str) -> Type: # fast path t = self.fast_container_type(e, fullname) if t: return t + # try to infer directly + if isinstance(e, ListExpr): + item_types = [self.accept(item) for item in e.items] + if not item_types: + item_type = AnyType(TypeOfAny.from_empty_collection) + else: + # use all the element types + item_type = self.infer_item_type(item_types) + return self.chk.named_generic_type(fullname, [item_type]) + + #!BUG this forces all list exprs to be forced (we don't want that) + # Translate into type checking a generic function call. - # Used for list and set expressions, as well as for tuples - # containing star expressions that don't refer to a - # Tuple. (Note: "lst" stands for list-set-tuple. :-) tv = TypeVarType( "T", "T", @@ -5034,6 +5064,31 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: upper_bound=self.object_type(), default=AnyType(TypeOfAny.from_omitted_generics), ) + + # get all elements and join types from list items + if isinstance(e, ListExpr): + item_types: list[Type] = [] + for item in e.items: + if isinstance(item, StarExpr): + starred_type = self.accept(item.expr) + starred_type = get_proper_type(starred_type) + if isinstance(starred_type, TupleType): + item_types.extend(starred_type.items) + else: + item_types.append(starred_type) + else: + item_types.append(self.accept(item)) + unified_type = join.join_type_list(item_types) + if not isinstance(unified_type, (AnyType, UninhabitedType)): + tv = TypeVarType( + "T", + "T", + id=TypeVarId(-1, namespace=""), + values=[], + upper_bound=unified_type, + default=unified_type, + ) + constructor = CallableType( [tv], [nodes.ARG_STAR], @@ -5043,6 +5098,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: name=tag, variables=[tv], ) + out = self.check_call( constructor, [(i.expr if isinstance(i, StarExpr) else i) for i in e.items], @@ -5735,21 +5791,17 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None: _, sequence_type = self.chk.analyze_async_iterable_item_type(sequence) else: _, sequence_type = self.chk.analyze_iterable_item_type(sequence) + sequence_type = get_proper_type(sequence_type) + if isinstance(sequence_type, Instance): + if sequence_type.type.fullname == "builtins.list": + item_types = [] + if isinstance(sequence, ListExpr): + for item in sequence.items: + item_type = self.accept(item) + item_types.append(item_type) + if item_types: + sequence_type = make_simplified_union(item_types) self.chk.analyze_index_variables(index, sequence_type, True, e) - for condition in conditions: - self.accept(condition) - - # values are only part of the comprehension when all conditions are true - true_map, false_map = self.chk.find_isinstance_check(condition) - - if true_map: - self.chk.push_type_map(true_map) - - if codes.REDUNDANT_EXPR in self.chk.options.enabled_error_codes: - if true_map is None: - self.msg.redundant_condition_in_comprehension(False, condition) - elif false_map is None: - self.msg.redundant_condition_in_comprehension(True, condition) def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type: self.accept(e.cond) diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index cd26c9bb408a..58ae97d68224 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2452,3 +2452,23 @@ x + T # E: Unsupported left operand type for + ("int") T() # E: "TypeVar" not callable [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] + +[case testListComprehensionWithTupleUnionTypeGenerator] +class A: pass +class B: pass +a = A() +b = B() +l3: list[A | B] = [x for x in (a, b)] +l3: list[A | B] = [x for x in [a, b]] + +[case testListComprehensionWithListUnionTypeGenerator] +a = A() +b = "foo" +l3: list[A | str] = [x for x in (a, b)] +l3: list[A | str] = [x for x in [a, b]] + +[case testListComprehensionWithImmutableTypeUnionTypeGenerator] +a = 3.0 +b = 3 +l3: list[float | int] = [x for x in (a, b)] +l3: list[float | int] = [x for x in [a, b]]