From fe3322acd41d4ac68c15c5cae0b50b4e76d6e074 Mon Sep 17 00:00:00 2001 From: Andrew Youn <52907065+ay0503@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:13:48 -0500 Subject: [PATCH 1/5] Force list expression to infer Union type of member types. --- mypy/checkexpr.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 11a9cffe18c3..cb7c42a97e46 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5016,12 +5016,39 @@ def fast_container_type( self.resolved_type[e] = ct return ct + def infer_item_type(self, item_types: List[Type]) -> Type: + """Infer the item type for a list based on its elements.""" + joined_type = self.chk.join_types(*item_types) + proper_joined = get_proper_type(joined_type) + if ( + isinstance(proper_joined, Instance) + and proper_joined.type.fullname == "builtins.object" + and len(set(map(type, item_types))) > 1 + ): + # if we can't find a common supertype other than 'object', + # use a Union of the item types + return UnionType.make_simplified_union(item_types) + else: + # otherwise just use the common supertype + return joined_type + def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: str) -> Type: # fast path t = self.fast_container_type(e, fullname) if t: return t + # if a ListExpr, just infer the item type directly + if isinstance(e, ListExpr): + item_types = [self.accept(item) for item in e.items] + if not item_types: + # empty list, default to Any + item_type = AnyType(TypeOfAny.from_empty_collection) + else: + # attempt to find a common supertype + item_type = self.infer_item_type(item_types) + return self.chk.named_generic_type(fullname, [item_type]) + # Translate into type checking a generic function call. # Used for list and set expressions, as well as for tuples # containing star expressions that don't refer to a From e5893cf9e1e02aa5ab43e750ae8a58c3683d969d Mon Sep 17 00:00:00 2001 From: Andrew Youn <52907065+ay0503@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:23:18 -0500 Subject: [PATCH 2/5] Add unit test check-list-expr. --- test-data/unit/check-expressions.test | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index cd26c9bb408a..2a20b3f77add 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2452,3 +2452,10 @@ x + T # E: Unsupported left operand type for + ("int") T() # E: "TypeVar" not callable [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] + +[case testListComprehensionWithUnionTypeGenerator] +class A: pass +class B: pass +a = A() +b = B() +l3: list[A | B] = [x for x in [a, b]] \ No newline at end of file From b03bedcbaf64b8f1679fde2ae6fb5598fe2f2d87 Mon Sep 17 00:00:00 2001 From: Andrew Youn <52907065+ay0503@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:59:46 -0500 Subject: [PATCH 3/5] Add additional tests for immutable/mutable type cases and tuple vs. list iterators. --- test-data/unit/check-expressions.test | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 2a20b3f77add..58ae97d68224 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2453,9 +2453,22 @@ T() # E: "TypeVar" not callable [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] -[case testListComprehensionWithUnionTypeGenerator] +[case testListComprehensionWithTupleUnionTypeGenerator] class A: pass class B: pass a = A() b = B() -l3: list[A | B] = [x for x in [a, b]] \ No newline at end of file +l3: list[A | B] = [x for x in (a, b)] +l3: list[A | B] = [x for x in [a, b]] + +[case testListComprehensionWithListUnionTypeGenerator] +a = A() +b = "foo" +l3: list[A | str] = [x for x in (a, b)] +l3: list[A | str] = [x for x in [a, b]] + +[case testListComprehensionWithImmutableTypeUnionTypeGenerator] +a = 3.0 +b = 3 +l3: list[float | int] = [x for x in (a, b)] +l3: list[float | int] = [x for x in [a, b]] From 0f28a5f3e375b30a072f03bce7be9f31fe324355 Mon Sep 17 00:00:00 2001 From: Andrew Youn <52907065+ay0503@users.noreply.github.com> Date: Fri, 6 Dec 2024 18:04:16 -0500 Subject: [PATCH 4/5] Add status comments. --- mypy/checkexpr.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index cb7c42a97e46..501c424611aa 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5038,17 +5038,18 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: if t: return t - # if a ListExpr, just infer the item type directly + # try to infer directly if isinstance(e, ListExpr): item_types = [self.accept(item) for item in e.items] if not item_types: - # empty list, default to Any item_type = AnyType(TypeOfAny.from_empty_collection) else: - # attempt to find a common supertype + # use all the element types item_type = self.infer_item_type(item_types) return self.chk.named_generic_type(fullname, [item_type]) + #!BUG this forces all list exprs to be forced (we don't want that) + # Translate into type checking a generic function call. # Used for list and set expressions, as well as for tuples # containing star expressions that don't refer to a From f5cc35788085ceb895f22e407ff96697c7bf7b4b Mon Sep 17 00:00:00 2001 From: Andrew Youn <52907065+ay0503@users.noreply.github.com> Date: Sun, 8 Dec 2024 01:03:46 -0500 Subject: [PATCH 5/5] Modify approach to collectively infer union instead of forcing. --- mypy/checkexpr.py | 64 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 501c424611aa..b06f951b698a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4977,7 +4977,13 @@ def apply_type_arguments_to_callable( def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e, "builtins.list", "") + if len(e.items) > 0: + item_types = [] + for item in e.items: + item_type = self.accept(item) + item_types.append(item_type) + element_type = make_simplified_union(item_types) + return self.chk.named_generic_type("builtins.list", [element_type]) def visit_set_expr(self, e: SetExpr) -> Type: return self.check_lst_expr(e, "builtins.set", "") @@ -5004,11 +5010,10 @@ def fast_container_type( values: list[Type] = [] for item in e.items: if isinstance(item, StarExpr): - # fallback to slow path self.resolved_type[e] = NoneType() return None values.append(self.accept(item)) - vt = join.join_type_list(values) + vt = make_simplified_union(values) if not allow_fast_container_literal(vt): self.resolved_type[e] = NoneType() return None @@ -5051,9 +5056,6 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: #!BUG this forces all list exprs to be forced (we don't want that) # Translate into type checking a generic function call. - # Used for list and set expressions, as well as for tuples - # containing star expressions that don't refer to a - # Tuple. (Note: "lst" stands for list-set-tuple. :-) tv = TypeVarType( "T", "T", @@ -5062,6 +5064,31 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: upper_bound=self.object_type(), default=AnyType(TypeOfAny.from_omitted_generics), ) + + # get all elements and join types from list items + if isinstance(e, ListExpr): + item_types: list[Type] = [] + for item in e.items: + if isinstance(item, StarExpr): + starred_type = self.accept(item.expr) + starred_type = get_proper_type(starred_type) + if isinstance(starred_type, TupleType): + item_types.extend(starred_type.items) + else: + item_types.append(starred_type) + else: + item_types.append(self.accept(item)) + unified_type = join.join_type_list(item_types) + if not isinstance(unified_type, (AnyType, UninhabitedType)): + tv = TypeVarType( + "T", + "T", + id=TypeVarId(-1, namespace=""), + values=[], + upper_bound=unified_type, + default=unified_type, + ) + constructor = CallableType( [tv], [nodes.ARG_STAR], @@ -5071,6 +5098,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: name=tag, variables=[tv], ) + out = self.check_call( constructor, [(i.expr if isinstance(i, StarExpr) else i) for i in e.items], @@ -5763,21 +5791,17 @@ def check_for_comp(self, e: GeneratorExpr | DictionaryComprehension) -> None: _, sequence_type = self.chk.analyze_async_iterable_item_type(sequence) else: _, sequence_type = self.chk.analyze_iterable_item_type(sequence) + sequence_type = get_proper_type(sequence_type) + if isinstance(sequence_type, Instance): + if sequence_type.type.fullname == "builtins.list": + item_types = [] + if isinstance(sequence, ListExpr): + for item in sequence.items: + item_type = self.accept(item) + item_types.append(item_type) + if item_types: + sequence_type = make_simplified_union(item_types) self.chk.analyze_index_variables(index, sequence_type, True, e) - for condition in conditions: - self.accept(condition) - - # values are only part of the comprehension when all conditions are true - true_map, false_map = self.chk.find_isinstance_check(condition) - - if true_map: - self.chk.push_type_map(true_map) - - if codes.REDUNDANT_EXPR in self.chk.options.enabled_error_codes: - if true_map is None: - self.msg.redundant_condition_in_comprehension(False, condition) - elif false_map is None: - self.msg.redundant_condition_in_comprehension(True, condition) def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = False) -> Type: self.accept(e.cond)