diff --git a/mypy/checker.py b/mypy/checker.py index 68f9bd4c1383..958b7d75dd4c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -146,6 +146,7 @@ find_member, infer_class_variances, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_more_precise, is_proper_subtype, @@ -6783,13 +6784,16 @@ def should_coerce_inner(typ: Type) -> bool: expr_type = coerce_to_literal(expr_type) if not is_valid_target(get_proper_type(expr_type)): continue - if target and not is_same_type(target, expr_type): + if ( + target is not None + and not is_same_type(target, expr_type) + and not is_enum_value_pair(target, expr_type) + ): # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} target = expr_type possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. if target is None: @@ -6862,8 +6866,28 @@ def should_coerce_inner(typ: Type) -> bool: # We intentionally use 'conditional_types' directly here instead of # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - types = conditional_types(expr_type, target_type) - partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) + yes, no = conditional_types(expr_type, target_type) + # If we encounter `enum_value == 1` checks (enum vs literal), we do not want + # to narrow the former to literal and should preserve the enum identity. + # TODO: maybe we should infer literals here? + if ( + isinstance(get_proper_type(yes), LiteralType) + and isinstance(proper_expr := get_proper_type(expr_type), Instance) + and proper_expr.type.is_enum + ): + yes_items = [] + for name in proper_expr.type.enum_members: + e = proper_expr.type.get(name) + if ( + e is not None + and isinstance(proper_e := get_proper_type(e.type), Instance) + and proper_e.last_known_value == yes + ): + name_val = LiteralType(name, fallback=proper_expr) + yes_items.append(proper_expr.copy_modified(last_known_value=name_val)) + if yes_items: + yes = UnionType.make_union(yes_items) + partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no)) return reduce_conditional_maps(partial_type_maps) @@ -9125,7 +9149,9 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: # let's be conservative result.add("") elif isinstance(t, LiteralType): - result.update(_ambiguous_enum_variants([t.fallback])) + if t.fallback.type.is_enum: + result.update(_ambiguous_enum_variants([t.fallback])) + # Other literals (str, int, bool) cannot introduce any surprises elif isinstance(t, NoneType): pass else: diff --git a/mypy/meet.py b/mypy/meet.py index 349c15e668c3..2df18f2a3cf4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -10,6 +10,7 @@ are_parameters_compatible, find_member, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_proper_subtype, is_same_type, @@ -547,9 +548,16 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value: + if ( + left.value == right.value + and left.fallback.type.is_enum == right.fallback.type.is_enum + or is_enum_value_pair(left, right) + ): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. + # Enums are more interesting: + # * if both sides are enums, they should have same values + # * if exactly one of them is a enum, fallback compatibibility is enough left = left.fallback right = right.fallback else: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..772fce2a5f96 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -37,6 +37,7 @@ from mypy.options import Options from mypy.state import state from mypy.types import ( + ELLIPSIS_TYPE_NAMES, MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, TYPED_NAMEDTUPLE_NAMES, @@ -286,6 +287,35 @@ def is_same_type( ) +def is_enum_value_pair(a: Type, b: Type) -> bool: + a = get_proper_type(a) + b = get_proper_type(b) + + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): + return False + if b.fallback.type.is_enum: + a, b = b, a + if b.fallback.type.is_enum or not a.fallback.type.is_enum: + return False + # At this point we have a pair (enum literal, non-enum literal). + # Check that the non-enum fallback is compatible + if not is_subtype(a.fallback, b.fallback): + return False + assert isinstance(a.value, str) + enum_value = a.fallback.type.get(a.value) + if enum_value is None or enum_value.type is None: + return False + proper_value = get_proper_type(enum_value.type) + return isinstance(proper_value, Instance) and ( + proper_value.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES + ) + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 3bcf9745a801..18ab4ee7794d 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2681,3 +2681,200 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis" reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis" [builtins fixtures/enum.pyi] + +[case testEnumItemsEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = "b" + b = "a" + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal['a']") + 1 + 'a' +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +class B(StrEnum): + a = "b" + b = "a" + +if B.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal['a']") + 1 + 'a' +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +class C(IntEnum): + a = 0 + b = 1 + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[1]") + 1 + 'a' + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +class D(int, Enum): + a = 0 + b = 1 + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[1]") + 1 + 'a' + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsInStub] +# flags: --python-version=3.11 --strict-equality +from mystub import A, B, C, D + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +if B.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' + +[file mystub.pyi] +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = ... + b = ... + +class B(StrEnum): + a = ... + b = ... + +class C(int, Enum): + a = ... + b = ... + +class D(IntEnum): + a = ... + b = ... +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsWithAlias-xfail] +# flags: --python-version=3.11 --strict-equality +# TODO: mypy does not support enum member aliases now. +from enum import Enum, IntEnum + +class A(str, Enum): + a = "c" + b = a + +if A.a == A.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +class B(IntEnum): + a = 0 + b = a + +if B.a == B.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7fffd3ce94e5..cd2487a6e42e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None: else: reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + if x == 1: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + def f4(x: IE | E) -> None: if x == IE.X: reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"