From ab3ed33b65f53786d214be34d147a60edaf1b887 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 5 Aug 2025 01:47:22 +0200 Subject: [PATCH 1/5] Fix enum comparison with literal values --- mypy/meet.py | 5 +++- test-data/unit/check-enum.test | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/mypy/meet.py b/mypy/meet.py index 349c15e668c3..229c8a75b78d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -547,9 +547,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value: + if left.value == right.value or (left.fallback.type.is_enum ^ right.fallback.type.is_enum): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. + # Enums are more interesting: + # * if both sides are enums, they should have same values + # * if exactly one of them is a enum, fallback compatibibility is enough left = left.fallback right = right.fallback else: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 3bcf9745a801..f10975926756 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2681,3 +2681,50 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis" reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis" [builtins fixtures/enum.pyi] + +[case testEnumItemsEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = "b" + b = "a" + +A.a == "a" +A.a == "b" + +A.a == A.a +A.a == A.b # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + +class B(StrEnum): + a = "b" + b = "a" + +B.a == "a" +B.a == "b" + +B.a == B.a +B.a == B.b # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + +B.a == A.a # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + +class C(IntEnum): + a = 0 + +C.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") +C.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + +C.a == C.a +C.a == C.b + +class D(int, Enum): + a = 0 + +D.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") +D.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + +D.a == D.a +D.a == D.b + +D.a == C.a # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") +[builtins fixtures/dict.pyi] From 350aa53ed54d7dd3668b44487a18dff217016cf9 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 18:49:49 +0200 Subject: [PATCH 2/5] Sync reachability and comparison overlap checks --- mypy/checker.py | 12 ++- mypy/meet.py | 7 +- mypy/subtypes.py | 29 +++++ test-data/unit/check-enum.test | 186 +++++++++++++++++++++++++++++---- 4 files changed, 212 insertions(+), 22 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 68f9bd4c1383..f1ec4cbfdeaf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -146,6 +146,7 @@ find_member, infer_class_variances, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_more_precise, is_proper_subtype, @@ -6783,13 +6784,16 @@ def should_coerce_inner(typ: Type) -> bool: expr_type = coerce_to_literal(expr_type) if not is_valid_target(get_proper_type(expr_type)): continue - if target and not is_same_type(target, expr_type): + if ( + target is not None + and not is_same_type(target, expr_type) + and not is_enum_value_pair(target, expr_type) + ): # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} target = expr_type possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. if target is None: @@ -9125,7 +9129,9 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: # let's be conservative result.add("") elif isinstance(t, LiteralType): - result.update(_ambiguous_enum_variants([t.fallback])) + if t.fallback.type.is_enum: + result.update(_ambiguous_enum_variants([t.fallback])) + # Other literals (str, int, bool) cannot introduce any surprises elif isinstance(t, NoneType): pass else: diff --git a/mypy/meet.py b/mypy/meet.py index 229c8a75b78d..2df18f2a3cf4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -10,6 +10,7 @@ are_parameters_compatible, find_member, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_proper_subtype, is_same_type, @@ -547,7 +548,11 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value or (left.fallback.type.is_enum ^ right.fallback.type.is_enum): + if ( + left.value == right.value + and left.fallback.type.is_enum == right.fallback.type.is_enum + or is_enum_value_pair(left, right) + ): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. # Enums are more interesting: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..f1be900c6232 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -37,6 +37,7 @@ from mypy.options import Options from mypy.state import state from mypy.types import ( + ELLIPSIS_TYPE_NAMES, MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, TYPED_NAMEDTUPLE_NAMES, @@ -286,6 +287,34 @@ def is_same_type( ) +def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): + return False + if b.fallback.type.is_enum: + a, b = b, a + if b.fallback.type.is_enum: + return False + # At this point we have a pair (non-enum literal, enum literal). + # Check that the non-enum fallback is compatible + if not is_subtype(a.fallback, b.fallback): + return False + assert isinstance(a.value, str) + enum_value = a.fallback.type.get(a.value) + return ( + enum_value is not None + and enum_value.type is not None + and isinstance(enum_value.type, Instance) + and ( + enum_value.type.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or enum_value.type.type.fullname in ELLIPSIS_TYPE_NAMES + ) + ) + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index f10975926756..18ab4ee7794d 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2690,41 +2690,191 @@ class A(str, Enum): a = "b" b = "a" -A.a == "a" -A.a == "b" +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) -A.a == A.a -A.a == A.b # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") +if A.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal['a']") + 1 + 'a' +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' class B(StrEnum): a = "b" b = "a" -B.a == "a" -B.a == "b" +if B.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal['a']") + 1 + 'a' +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") -B.a == B.a -B.a == B.b # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' -B.a == A.a # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' class C(IntEnum): a = 0 + b = 1 + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' -C.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") -C.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[1]") + 1 + 'a' -C.a == C.a -C.a == C.b +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' class D(int, Enum): a = 0 + b = 1 + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[1]") + 1 + 'a' + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' -D.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") -D.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsInStub] +# flags: --python-version=3.11 --strict-equality +from mystub import A, B, C, D + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +if B.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' + +[file mystub.pyi] +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = ... + b = ... -D.a == D.a -D.a == D.b +class B(StrEnum): + a = ... + b = ... + +class C(int, Enum): + a = ... + b = ... + +class D(IntEnum): + a = ... + b = ... +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsWithAlias-xfail] +# flags: --python-version=3.11 --strict-equality +# TODO: mypy does not support enum member aliases now. +from enum import Enum, IntEnum + +class A(str, Enum): + a = "c" + b = a + +if A.a == A.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +class B(IntEnum): + a = 0 + b = a -D.a == C.a # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") +if B.a == B.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") [builtins fixtures/dict.pyi] From ebfc647bcbd243d191f71cfdef0cba6a1a73ce41 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 19:10:47 +0200 Subject: [PATCH 3/5] Fix selfcheck --- mypy/subtypes.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f1be900c6232..b2233ace9831 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -287,7 +287,10 @@ def is_same_type( ) -def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: +def is_enum_value_pair(a: Type, b: Type) -> bool: + a = get_proper_type(a) + b = get_proper_type(b) + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): return False if b.fallback.type.is_enum: @@ -300,18 +303,16 @@ def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: return False assert isinstance(a.value, str) enum_value = a.fallback.type.get(a.value) - return ( - enum_value is not None - and enum_value.type is not None - and isinstance(enum_value.type, Instance) - and ( - enum_value.type.last_known_value == b - # TODO: this is too lax and should only be applied for enums defined in stubs, - # but checking that strictly requires access to the checker. This function - # is needed in `is_overlapping_types` and operates on a lower level, - # so doing this properly would be more difficult. - or enum_value.type.type.fullname in ELLIPSIS_TYPE_NAMES - ) + if enum_value is None or enum_value.type is None: + return False + proper_value = get_proper_type(enum_value.type) + return isinstance(proper_value, Instance) and ( + proper_value.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES ) From 0001936504a9c65ba509156bd6b06b4024548f18 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 19:34:03 +0200 Subject: [PATCH 4/5] Oops --- mypy/subtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index b2233ace9831..772fce2a5f96 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -295,9 +295,9 @@ def is_enum_value_pair(a: Type, b: Type) -> bool: return False if b.fallback.type.is_enum: a, b = b, a - if b.fallback.type.is_enum: + if b.fallback.type.is_enum or not a.fallback.type.is_enum: return False - # At this point we have a pair (non-enum literal, enum literal). + # At this point we have a pair (enum literal, non-enum literal). # Check that the non-enum fallback is compatible if not is_subtype(a.fallback, b.fallback): return False From c87c7074149944b7578f391f6bb3a2de270026fa Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 20:46:46 +0200 Subject: [PATCH 5/5] Prevent narrowing by equality to overlapping literals (discarding enum info). --- mypy/checker.py | 24 ++++++++++++++++++++++-- test-data/unit/check-narrowing.test | 5 +++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f1ec4cbfdeaf..958b7d75dd4c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6866,8 +6866,28 @@ def should_coerce_inner(typ: Type) -> bool: # We intentionally use 'conditional_types' directly here instead of # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - types = conditional_types(expr_type, target_type) - partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) + yes, no = conditional_types(expr_type, target_type) + # If we encounter `enum_value == 1` checks (enum vs literal), we do not want + # to narrow the former to literal and should preserve the enum identity. + # TODO: maybe we should infer literals here? + if ( + isinstance(get_proper_type(yes), LiteralType) + and isinstance(proper_expr := get_proper_type(expr_type), Instance) + and proper_expr.type.is_enum + ): + yes_items = [] + for name in proper_expr.type.enum_members: + e = proper_expr.type.get(name) + if ( + e is not None + and isinstance(proper_e := get_proper_type(e.type), Instance) + and proper_e.last_known_value == yes + ): + name_val = LiteralType(name, fallback=proper_expr) + yes_items.append(proper_expr.copy_modified(last_known_value=name_val)) + if yes_items: + yes = UnionType.make_union(yes_items) + partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no)) return reduce_conditional_maps(partial_type_maps) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7fffd3ce94e5..cd2487a6e42e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None: else: reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + if x == 1: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + def f4(x: IE | E) -> None: if x == IE.X: reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"