Skip to content

Commit c87c707

Browse files
committed
Prevent narrowing by equality to overlapping literals (discarding enum info).
1 parent 0001936 commit c87c707

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

mypy/checker.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6866,8 +6866,28 @@ def should_coerce_inner(typ: Type) -> bool:
68666866
# We intentionally use 'conditional_types' directly here instead of
68676867
# 'self.conditional_types_with_intersection': we only compute ad-hoc
68686868
# intersections when working with pure instances.
6869-
types = conditional_types(expr_type, target_type)
6870-
partial_type_maps.append(conditional_types_to_typemaps(expr, *types))
6869+
yes, no = conditional_types(expr_type, target_type)
6870+
# If we encounter `enum_value == 1` checks (enum vs literal), we do not want
6871+
# to narrow the former to literal and should preserve the enum identity.
6872+
# TODO: maybe we should infer literals here?
6873+
if (
6874+
isinstance(get_proper_type(yes), LiteralType)
6875+
and isinstance(proper_expr := get_proper_type(expr_type), Instance)
6876+
and proper_expr.type.is_enum
6877+
):
6878+
yes_items = []
6879+
for name in proper_expr.type.enum_members:
6880+
e = proper_expr.type.get(name)
6881+
if (
6882+
e is not None
6883+
and isinstance(proper_e := get_proper_type(e.type), Instance)
6884+
and proper_e.last_known_value == yes
6885+
):
6886+
name_val = LiteralType(name, fallback=proper_expr)
6887+
yes_items.append(proper_expr.copy_modified(last_known_value=name_val))
6888+
if yes_items:
6889+
yes = UnionType.make_union(yes_items)
6890+
partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no))
68716891

68726892
return reduce_conditional_maps(partial_type_maps)
68736893

test-data/unit/check-narrowing.test

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None:
22032203
else:
22042204
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
22052205

2206+
if x == 1:
2207+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
2208+
else:
2209+
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
2210+
22062211
def f4(x: IE | E) -> None:
22072212
if x == IE.X:
22082213
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"

0 commit comments

Comments
 (0)