|
29 | 29 | Type, |
30 | 30 | get_proper_type, |
31 | 31 | is_named_instance, |
| 32 | + UnionType, |
| 33 | + LiteralType, |
32 | 34 | ) |
33 | 35 |
|
34 | 36 | ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | { |
@@ -61,10 +63,9 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: |
61 | 63 | literal_type = LiteralType(enum_field_name, fallback=str_type) |
62 | 64 | return str_type.copy_modified(last_known_value=literal_type) |
63 | 65 |
|
64 | | - # Or `field: SomeEnum = SomeEnum.field; field.name` case: |
65 | | - if not isinstance(ctx.type, Instance) or not ctx.type.type.is_enum: |
66 | | - return ctx.default_attr_type |
67 | | - enum_names = ctx.type.type.enum_members |
| 66 | + # Or `field: SomeEnum = SomeEnum.field; field.name` case, |
| 67 | + # Or `field: Literal[Some.A, Some.B]; field.name` case: |
| 68 | + enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union(ctx.type) |
68 | 69 | if enum_names: |
69 | 70 | str_type = ctx.api.named_generic_type("builtins.str", []) |
70 | 71 | return make_simplified_union( |
@@ -296,3 +297,27 @@ def _extract_underlying_field_name(typ: Type) -> str | None: |
296 | 297 | # as a string. |
297 | 298 | assert isinstance(underlying_literal.value, str) |
298 | 299 | return underlying_literal.value |
| 300 | + |
| 301 | + |
| 302 | +def _extract_enum_names_from_type(typ: ProperType) -> list[str] | None: |
| 303 | + if not isinstance(typ, Instance) or not typ.type.is_enum: |
| 304 | + return None |
| 305 | + return typ.type.enum_members |
| 306 | + |
| 307 | + |
| 308 | +def _extract_enum_names_from_literal_union(typ: ProperType) -> list[str] | None: |
| 309 | + if not isinstance(typ, UnionType): |
| 310 | + return None |
| 311 | + |
| 312 | + names = [] |
| 313 | + for item in typ.relevant_items(): |
| 314 | + pitem = get_proper_type(item) |
| 315 | + if isinstance(pitem, Instance) and pitem.last_known_value and pitem.type.is_enum: |
| 316 | + assert isinstance(pitem.last_known_value.value, str) |
| 317 | + names.append(pitem.last_known_value.value) |
| 318 | + elif isinstance(pitem, LiteralType): |
| 319 | + assert isinstance(pitem.value, str) |
| 320 | + names.append(pitem.value) |
| 321 | + else: |
| 322 | + return None |
| 323 | + return names |
0 commit comments