Skip to content

Commit 79262e7

Browse files
committed
Handle (str, Enum) and (int, Enum) subclasses narrowing
1 parent 6c4f0aa commit 79262e7

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

mypy/checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9208,7 +9208,8 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
92089208
if t.last_known_value:
92099209
result.update(_ambiguous_enum_variants([t.last_known_value]))
92109210
elif t.type.is_enum and any(
9211-
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
9211+
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
9212+
for base in t.type.mro
92129213
):
92139214
result.add(t.type.fullname)
92149215
elif not t.type.is_enum:

test-data/unit/check-enum.test

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2881,3 +2881,108 @@ class B(IntEnum):
28812881
if B.a == B.b:
28822882
1 + 'a' # E: Unsupported operand types for + ("int" and "str")
28832883
[builtins fixtures/dict.pyi]
2884+
2885+
2886+
[case testEnumNarrowingByEqualityToLiterals]
2887+
# flags: --python-version=3.11 --strict-equality
2888+
from enum import Enum, StrEnum, IntEnum
2889+
2890+
# Every `if` block in this test should either reveal or report non-overlapping.
2891+
2892+
class A(str, Enum):
2893+
a = "b"
2894+
b = "a"
2895+
class B(StrEnum):
2896+
a = "b"
2897+
b = "a"
2898+
class C(int, Enum):
2899+
a = 0
2900+
b = 1
2901+
class D(IntEnum):
2902+
a = 0
2903+
b = 1
2904+
2905+
a: A
2906+
if a == A.a:
2907+
reveal_type(a) # N: Revealed type is "Literal[__main__.A.a]"
2908+
else:
2909+
reveal_type(a) # N: Revealed type is "Literal[__main__.A.b]"
2910+
2911+
if a == "a":
2912+
reveal_type(a) # N: Revealed type is "__main__.A"
2913+
else:
2914+
reveal_type(a) # N: Revealed type is "__main__.A"
2915+
2916+
if a == "c":
2917+
reveal_type(a) # N: Revealed type is "__main__.A"
2918+
else:
2919+
reveal_type(a) # N: Revealed type is "__main__.A"
2920+
2921+
if a == 0: # E: Non-overlapping equality check (left operand type: "A", right operand type: "Literal[0]")
2922+
reveal_type(a)
2923+
else:
2924+
reveal_type(a) # N: Revealed type is "__main__.A"
2925+
2926+
b: B
2927+
if b == B.a:
2928+
reveal_type(b) # N: Revealed type is "Literal[__main__.B.a]"
2929+
else:
2930+
reveal_type(b) # N: Revealed type is "Literal[__main__.B.b]"
2931+
2932+
if b == "a":
2933+
reveal_type(b) # N: Revealed type is "__main__.B"
2934+
else:
2935+
reveal_type(b) # N: Revealed type is "__main__.B"
2936+
2937+
if b == "c":
2938+
reveal_type(b) # N: Revealed type is "__main__.B"
2939+
else:
2940+
reveal_type(b) # N: Revealed type is "__main__.B"
2941+
2942+
if b == 0: # E: Non-overlapping equality check (left operand type: "B", right operand type: "Literal[0]")
2943+
reveal_type(b)
2944+
else:
2945+
reveal_type(b) # N: Revealed type is "__main__.B"
2946+
2947+
c: C
2948+
if c == C.a:
2949+
reveal_type(c) # N: Revealed type is "Literal[__main__.C.a]"
2950+
else:
2951+
reveal_type(c) # N: Revealed type is "Literal[__main__.C.b]"
2952+
2953+
if c == 0:
2954+
reveal_type(c) # N: Revealed type is "__main__.C"
2955+
else:
2956+
reveal_type(c) # N: Revealed type is "__main__.C"
2957+
2958+
if c == 2:
2959+
reveal_type(c) # N: Revealed type is "__main__.C"
2960+
else:
2961+
reveal_type(c) # N: Revealed type is "__main__.C"
2962+
2963+
if c == "a": # E: Non-overlapping equality check (left operand type: "C", right operand type: "Literal['a']")
2964+
reveal_type(c)
2965+
else:
2966+
reveal_type(c) # N: Revealed type is "__main__.C"
2967+
2968+
d: D
2969+
if d == D.a:
2970+
reveal_type(d) # N: Revealed type is "Literal[__main__.D.a]"
2971+
else:
2972+
reveal_type(d) # N: Revealed type is "Literal[__main__.D.b]"
2973+
2974+
if d == 0:
2975+
reveal_type(d) # N: Revealed type is "__main__.D"
2976+
else:
2977+
reveal_type(d) # N: Revealed type is "__main__.D"
2978+
2979+
if d == 2:
2980+
reveal_type(d) # N: Revealed type is "__main__.D"
2981+
else:
2982+
reveal_type(d) # N: Revealed type is "__main__.D"
2983+
2984+
if d == "a": # E: Non-overlapping equality check (left operand type: "D", right operand type: "Literal['a']")
2985+
reveal_type(d)
2986+
else:
2987+
reveal_type(d) # N: Revealed type is "__main__.D"
2988+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)