Skip to content

Better support for SomeEnum.item == some_literal #19594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
find_member,
infer_class_variances,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_more_precise,
is_proper_subtype,
Expand Down Expand Up @@ -6783,13 +6784,16 @@ def should_coerce_inner(typ: Type) -> bool:
expr_type = coerce_to_literal(expr_type)
if not is_valid_target(get_proper_type(expr_type)):
continue
if target and not is_same_type(target, expr_type):
if (
target is not None
and not is_same_type(target, expr_type)
and not is_enum_value_pair(target, expr_type)
):
# We have multiple disjoint target types. So the 'if' branch
# must be unreachable.
return None, {}
target = expr_type
possible_target_indices.append(i)

# There's nothing we can currently infer if none of the operands are valid targets,
# so we end early and infer nothing.
if target is None:
Expand Down Expand Up @@ -6862,8 +6866,28 @@ def should_coerce_inner(typ: Type) -> bool:
# We intentionally use 'conditional_types' directly here instead of
# 'self.conditional_types_with_intersection': we only compute ad-hoc
# intersections when working with pure instances.
types = conditional_types(expr_type, target_type)
partial_type_maps.append(conditional_types_to_typemaps(expr, *types))
yes, no = conditional_types(expr_type, target_type)
# If we encounter `enum_value == 1` checks (enum vs literal), we do not want
# to narrow the former to literal and should preserve the enum identity.
# TODO: maybe we should infer literals here?
if (
isinstance(get_proper_type(yes), LiteralType)
and isinstance(proper_expr := get_proper_type(expr_type), Instance)
and proper_expr.type.is_enum
):
yes_items = []
for name in proper_expr.type.enum_members:
e = proper_expr.type.get(name)
if (
e is not None
and isinstance(proper_e := get_proper_type(e.type), Instance)
and proper_e.last_known_value == yes
):
name_val = LiteralType(name, fallback=proper_expr)
yes_items.append(proper_expr.copy_modified(last_known_value=name_val))
if yes_items:
yes = UnionType.make_union(yes_items)
partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no))

return reduce_conditional_maps(partial_type_maps)

Expand Down Expand Up @@ -9125,7 +9149,9 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
# let's be conservative
result.add("<other>")
elif isinstance(t, LiteralType):
result.update(_ambiguous_enum_variants([t.fallback]))
if t.fallback.type.is_enum:
result.update(_ambiguous_enum_variants([t.fallback]))
# Other literals (str, int, bool) cannot introduce any surprises
elif isinstance(t, NoneType):
pass
else:
Expand Down
10 changes: 9 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
are_parameters_compatible,
find_member,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_proper_subtype,
is_same_type,
Expand Down Expand Up @@ -547,9 +548,16 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
right = right.fallback

if isinstance(left, LiteralType) and isinstance(right, LiteralType):
if left.value == right.value:
if (
left.value == right.value
and left.fallback.type.is_enum == right.fallback.type.is_enum
or is_enum_value_pair(left, right)
):
# If values are the same, we still need to check if fallbacks are overlapping,
# this is done below.
# Enums are more interesting:
# * if both sides are enums, they should have same values
# * if exactly one of them is a enum, fallback compatibibility is enough
left = left.fallback
right = right.fallback
else:
Expand Down
30 changes: 30 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mypy.options import Options
from mypy.state import state
from mypy.types import (
ELLIPSIS_TYPE_NAMES,
MYPYC_NATIVE_INT_NAMES,
TUPLE_LIKE_INSTANCE_NAMES,
TYPED_NAMEDTUPLE_NAMES,
Expand Down Expand Up @@ -286,6 +287,35 @@ def is_same_type(
)


def is_enum_value_pair(a: Type, b: Type) -> bool:
a = get_proper_type(a)
b = get_proper_type(b)

if not isinstance(a, LiteralType) or not isinstance(b, LiteralType):
return False
if b.fallback.type.is_enum:
a, b = b, a
if b.fallback.type.is_enum or not a.fallback.type.is_enum:
return False
# At this point we have a pair (enum literal, non-enum literal).
# Check that the non-enum fallback is compatible
if not is_subtype(a.fallback, b.fallback):
return False
assert isinstance(a.value, str)
enum_value = a.fallback.type.get(a.value)
if enum_value is None or enum_value.type is None:
return False
proper_value = get_proper_type(enum_value.type)
return isinstance(proper_value, Instance) and (
proper_value.last_known_value == b
# TODO: this is too lax and should only be applied for enums defined in stubs,
# but checking that strictly requires access to the checker. This function
# is needed in `is_overlapping_types` and operates on a lower level,
# so doing this properly would be more difficult.
or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES
)


# This is a common entry point for subtyping checks (both proper and non-proper).
# Never call this private function directly, use the public versions.
def _is_subtype(
Expand Down
197 changes: 197 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2681,3 +2681,200 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper
reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis"
reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis"
[builtins fixtures/enum.pyi]

[case testEnumItemsEqualityToLiterals]
# flags: --python-version=3.11 --strict-equality
from enum import Enum, StrEnum, IntEnum

class A(str, Enum):
a = "b"
b = "a"

# Every `if` block in this test should have an error on exactly one of two lines.
# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping)

if A.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal['a']")
1 + 'a'
if A.a == "b":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]")
1 + 'a'

if A.a == A.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]")
1 + 'a'

class B(StrEnum):
a = "b"
b = "a"

if B.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal['a']")
1 + 'a'
if B.a == "b":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]")
1 + 'a'

if B.a == B.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]")
1 + 'a'

if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]")
1 + 'a'

class C(IntEnum):
a = 0
b = 1

if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']")
1 + 'a'
if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']")
1 + 'a'

if C.a == 0:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if C.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[1]")
1 + 'a'

if C.a == C.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]")
1 + 'a'

class D(int, Enum):
a = 0
b = 1

if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']")
1 + 'a'
if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']")
1 + 'a'

if D.a == 0:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if D.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[1]")
1 + 'a'

if D.a == D.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]")
1 + 'a'

if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]")
1 + 'a'
[builtins fixtures/dict.pyi]


[case testEnumItemsEqualityToLiteralsInStub]
# flags: --python-version=3.11 --strict-equality
from mystub import A, B, C, D

# Every `if` block in this test should have an error on exactly one of two lines.
# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping)

if A.a == "a":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if A.a == "b":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]")
1 + 'a'

if A.a == A.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]")
1 + 'a'

if B.a == "a":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if B.a == "b":
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]")
1 + 'a'

if B.a == B.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]")
1 + 'a'

if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]")
1 + 'a'

if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']")
1 + 'a'
if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']")
1 + 'a'

if C.a == 0:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if C.a == 1:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if C.a == C.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]")
1 + 'a'

if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']")
1 + 'a'
if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']")
1 + 'a'

if D.a == 0:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if D.a == 1:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

if D.a == D.a:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]")
1 + 'a'

if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]")
1 + 'a'

[file mystub.pyi]
from enum import Enum, StrEnum, IntEnum

class A(str, Enum):
a = ...
b = ...

class B(StrEnum):
a = ...
b = ...

class C(int, Enum):
a = ...
b = ...

class D(IntEnum):
a = ...
b = ...
[builtins fixtures/dict.pyi]


[case testEnumItemsEqualityToLiteralsWithAlias-xfail]
# flags: --python-version=3.11 --strict-equality
# TODO: mypy does not support enum member aliases now.
from enum import Enum, IntEnum

class A(str, Enum):
a = "c"
b = a

if A.a == A.b:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")

class B(IntEnum):
a = 0
b = a

if B.a == B.b:
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
[builtins fixtures/dict.pyi]
5 changes: 5 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None:
else:
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"

if x == 1:
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"
else:
reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]"

def f4(x: IE | E) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
Expand Down
Loading