Skip to content

Commit 97bdaff

Browse files
committed
Address review
1 parent efc6a0f commit 97bdaff

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

mypy/plugins/enums.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
Type,
3030
get_proper_type,
3131
is_named_instance,
32+
UnionType,
33+
LiteralType,
3234
)
3335

3436
ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
@@ -61,10 +63,9 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type:
6163
literal_type = LiteralType(enum_field_name, fallback=str_type)
6264
return str_type.copy_modified(last_known_value=literal_type)
6365

64-
# Or `field: SomeEnum = SomeEnum.field; field.name` case:
65-
if not isinstance(ctx.type, Instance) or not ctx.type.type.is_enum:
66-
return ctx.default_attr_type
67-
enum_names = ctx.type.type.enum_members
66+
# Or `field: SomeEnum = SomeEnum.field; field.name` case,
67+
# Or `field: Literal[Some.A, Some.B]; field.name` case:
68+
enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union(ctx.type)
6869
if enum_names:
6970
str_type = ctx.api.named_generic_type("builtins.str", [])
7071
return make_simplified_union(
@@ -296,3 +297,27 @@ def _extract_underlying_field_name(typ: Type) -> str | None:
296297
# as a string.
297298
assert isinstance(underlying_literal.value, str)
298299
return underlying_literal.value
300+
301+
302+
def _extract_enum_names_from_type(typ: ProperType) -> list[str] | None:
303+
if not isinstance(typ, Instance) or not typ.type.is_enum:
304+
return None
305+
return typ.type.enum_members
306+
307+
308+
def _extract_enum_names_from_literal_union(typ: ProperType) -> list[str] | None:
309+
if not isinstance(typ, UnionType):
310+
return None
311+
312+
names = []
313+
for item in typ.relevant_items():
314+
pitem = get_proper_type(item)
315+
if isinstance(pitem, Instance) and pitem.last_known_value and pitem.type.is_enum:
316+
assert isinstance(pitem.last_known_value.value, str)
317+
names.append(pitem.last_known_value.value)
318+
elif isinstance(pitem, LiteralType):
319+
assert isinstance(pitem.value, str)
320+
names.append(pitem.value)
321+
else:
322+
return None
323+
return names

test-data/unit/check-enum.test

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ e: Empty
126126
reveal_type(e.name) # N: Revealed type is "builtins.str"
127127
[builtins fixtures/tuple.pyi]
128128

129+
[case testEnumNameValueOnUnionOfLiteral]
130+
from enum import Enum
131+
from typing import Literal
132+
class Colors(Enum):
133+
red = 1
134+
blue = 2
135+
green = 3
136+
137+
color: Literal[Colors.red, Colors.blue]
138+
reveal_type(color.name) # N: Revealed type is "Union[Literal['red'], Literal['blue']]"
139+
[builtins fixtures/tuple.pyi]
140+
129141
[case testEnumValueExtended]
130142
from enum import Enum
131143
class Truth(Enum):

0 commit comments

Comments
 (0)