Skip to content

Commit 83e7789

Browse files
simplify literal elimination
1 parent aa4b74f commit 83e7789

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

mypy/test/testtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,8 @@ def test_simplified_union_with_str_instance_literals(self) -> None:
647647
def test_simplified_union_with_mixed_str_literals(self) -> None:
648648
fx = self.fx
649649

650+
self.assert_simplified_union([fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst)
651+
650652
self.assert_simplified_union(
651653
[fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
652654
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]),

mypy/typeops.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -602,25 +602,15 @@ def make_simplified_union(
602602
simplified_set = try_contracting_literals_in_union(simplified_set)
603603

604604
# Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]?
605-
new_items = []
606-
for item in simplified_set:
607-
if isinstance(item, LiteralType):
608-
# scan if there is an Instance with a last_known_value that matches
609-
for other in simplified_set:
610-
if (
611-
isinstance(other, Instance)
612-
and other.last_known_value is not None
613-
and item == other.last_known_value
614-
):
615-
# do not include item
616-
break
617-
else:
618-
new_items.append(item)
619-
else:
620-
# If the item is not a LiteralType, we can use it directly.
621-
new_items.append(item)
622-
623-
result = get_proper_type(UnionType.make_union(new_items, line, column))
605+
proper_items: list[ProperType] = list(map(get_proper_type, simplified_set))
606+
last_known_values: list[LiteralType | None] = [
607+
p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items
608+
]
609+
simplified_set = [
610+
item for item, p_t in zip(simplified_set, proper_items) if p_t not in last_known_values
611+
]
612+
613+
result = get_proper_type(UnionType.make_union(simplified_set, line, column))
624614

625615
nitems = len(items)
626616
if nitems > 1 and (

0 commit comments

Comments
 (0)