Skip to content

Commit 822dd30

Browse files
simplify literal elimination
1 parent 611ceaf commit 822dd30

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

mypy/test/testtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,8 @@ def test_simplified_union_with_str_instance_literals(self) -> None:
647647
def test_simplified_union_with_mixed_str_literals(self) -> None:
648648
fx = self.fx
649649

650+
self.assert_simplified_union([fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst)
651+
650652
self.assert_simplified_union(
651653
[fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
652654
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]),

mypy/typeops.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -603,25 +603,15 @@ def make_simplified_union(
603603
simplified_set = try_contracting_literals_in_union(simplified_set)
604604

605605
# Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]?
606-
new_items = []
607-
for item in simplified_set:
608-
if isinstance(item, LiteralType):
609-
# scan if there is an Instance with a last_known_value that matches
610-
for other in simplified_set:
611-
if (
612-
isinstance(other, Instance)
613-
and other.last_known_value is not None
614-
and item == other.last_known_value
615-
):
616-
# do not include item
617-
break
618-
else:
619-
new_items.append(item)
620-
else:
621-
# If the item is not a LiteralType, we can use it directly.
622-
new_items.append(item)
623-
624-
result = get_proper_type(UnionType.make_union(new_items, line, column))
606+
proper_items: list[ProperType] = list(map(get_proper_type, simplified_set))
607+
last_known_values: list[LiteralType | None] = [
608+
p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items
609+
]
610+
simplified_set = [
611+
item for item, p_t in zip(simplified_set, proper_items) if p_t not in last_known_values
612+
]
613+
614+
result = get_proper_type(UnionType.make_union(simplified_set, line, column))
625615

626616
nitems = len(items)
627617
if nitems > 1 and (

0 commit comments

Comments
 (0)