Skip to content

Commit dc66019

Browse files
authored
[ty] Expansion of enums into unions of literals (astral-sh#19382)
## Summary Implement expansion of enums into unions of enum literals (and the reverse operation). For the enum below, this allows us to understand that `Color = Literal[Color.RED, Color.GREEN, Color.BLUE]`, or that `Color & ~Literal[Color.RED] = Literal[Color.GREEN, Color.BLUE]`. This helps in exhaustiveness checking, which is why we see some removed `assert_never` false positives. And since exhaustiveness checking also helps with understanding terminal control flow, we also see a few removed `invalid-return-type` and `possibly-unresolved-reference` false positives. This PR also adds expansion of enums in overload resolution and type narrowing constructs. ```py from enum import Enum from typing_extensions import Literal, assert_never from ty_extensions import Intersection, Not, static_assert, is_equivalent_to class Color(Enum): RED = 1 GREEN = 2 BLUE = 3 type Red = Literal[Color.RED] type Green = Literal[Color.GREEN] type Blue = Literal[Color.BLUE] static_assert(is_equivalent_to(Red | Green | Blue, Color)) static_assert(is_equivalent_to(Intersection[Color, Not[Red]], Green | Blue)) def color_name(color: Color) -> str: # no error here (we detect that this can not implicitly return None) if color is Color.RED: return "Red" elif color is Color.GREEN: return "Green" elif color is Color.BLUE: return "Blue" else: assert_never(color) # no error here ``` ## Performance I avoided an initial regression here for large enums, but the `UnionBuilder` and `IntersectionBuilder` parts can certainly still be optimized. We might want to use the same technique that we also use for unions of other literals. I didn't see any problems in our benchmarks so far, so this is not included yet. ## Test Plan Many new Markdown tests
1 parent 926e833 commit dc66019

File tree

19 files changed

+752
-104
lines changed

19 files changed

+752
-104
lines changed

crates/ty_python_semantic/resources/mdtest/attributes.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,12 +2355,13 @@ import enum
23552355

23562356
reveal_type(enum.Enum.__members__) # revealed: MappingProxyType[str, Unknown]
23572357

2358-
class Foo(enum.Enum):
2359-
BAR = 1
2358+
class Answer(enum.Enum):
2359+
NO = 0
2360+
YES = 1
23602361

2361-
reveal_type(Foo.BAR) # revealed: Literal[Foo.BAR]
2362-
reveal_type(Foo.BAR.value) # revealed: Any
2363-
reveal_type(Foo.__members__) # revealed: MappingProxyType[str, Unknown]
2362+
reveal_type(Answer.NO) # revealed: Literal[Answer.NO]
2363+
reveal_type(Answer.NO.value) # revealed: Any
2364+
reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
23642365
```
23652366

23662367
## References

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def _(x: type[A | B]):
369369

370370
### Expanding enums
371371

372+
#### Basic
373+
372374
`overloaded.pyi`:
373375

374376
```pyi
@@ -394,15 +396,106 @@ def f(x: Literal[SomeEnum.C]) -> C: ...
394396
```
395397

396398
```py
399+
from typing import Literal
397400
from overloaded import SomeEnum, A, B, C, f
398401

399-
def _(x: SomeEnum):
402+
def _(x: SomeEnum, y: Literal[SomeEnum.A, SomeEnum.C]):
400403
reveal_type(f(SomeEnum.A)) # revealed: A
401404
reveal_type(f(SomeEnum.B)) # revealed: B
402405
reveal_type(f(SomeEnum.C)) # revealed: C
403-
# TODO: This should not be an error. The return type should be `A | B | C` once enums are expanded
404-
# error: [no-matching-overload]
405-
reveal_type(f(x)) # revealed: Unknown
406+
reveal_type(f(x)) # revealed: A | B | C
407+
reveal_type(f(y)) # revealed: A | C
408+
```
409+
410+
#### Enum with single member
411+
412+
This pattern appears in typeshed. Here, it is used to represent two optional, mutually exclusive
413+
keyword parameters:
414+
415+
`overloaded.pyi`:
416+
417+
```pyi
418+
from enum import Enum, auto
419+
from typing import overload, Literal
420+
421+
class Missing(Enum):
422+
Value = auto()
423+
424+
class OnlyASpecified: ...
425+
class OnlyBSpecified: ...
426+
class BothMissing: ...
427+
428+
@overload
429+
def f(*, a: int, b: Literal[Missing.Value] = ...) -> OnlyASpecified: ...
430+
@overload
431+
def f(*, a: Literal[Missing.Value] = ..., b: int) -> OnlyBSpecified: ...
432+
@overload
433+
def f(*, a: Literal[Missing.Value] = ..., b: Literal[Missing.Value] = ...) -> BothMissing: ...
434+
```
435+
436+
```py
437+
from typing import Literal
438+
from overloaded import f, Missing
439+
440+
reveal_type(f()) # revealed: BothMissing
441+
reveal_type(f(a=0)) # revealed: OnlyASpecified
442+
reveal_type(f(b=0)) # revealed: OnlyBSpecified
443+
444+
f(a=0, b=0) # error: [no-matching-overload]
445+
446+
def _(missing: Literal[Missing.Value], missing_or_present: Literal[Missing.Value] | int):
447+
reveal_type(f(a=missing, b=missing)) # revealed: BothMissing
448+
reveal_type(f(a=missing)) # revealed: BothMissing
449+
reveal_type(f(b=missing)) # revealed: BothMissing
450+
reveal_type(f(a=0, b=missing)) # revealed: OnlyASpecified
451+
reveal_type(f(a=missing, b=0)) # revealed: OnlyBSpecified
452+
453+
reveal_type(f(a=missing_or_present)) # revealed: BothMissing | OnlyASpecified
454+
reveal_type(f(b=missing_or_present)) # revealed: BothMissing | OnlyBSpecified
455+
456+
# Here, both could be present, so this should be an error
457+
f(a=missing_or_present, b=missing_or_present) # error: [no-matching-overload]
458+
```
459+
460+
#### Enum subclass without members
461+
462+
An `Enum` subclass without members should *not* be expanded:
463+
464+
`overloaded.pyi`:
465+
466+
```pyi
467+
from enum import Enum
468+
from typing import overload, Literal
469+
470+
class MyEnumSubclass(Enum):
471+
pass
472+
473+
class ActualEnum(MyEnumSubclass):
474+
A = 1
475+
B = 2
476+
477+
class OnlyA: ...
478+
class OnlyB: ...
479+
class Both: ...
480+
481+
@overload
482+
def f(x: Literal[ActualEnum.A]) -> OnlyA: ...
483+
@overload
484+
def f(x: Literal[ActualEnum.B]) -> OnlyB: ...
485+
@overload
486+
def f(x: ActualEnum) -> Both: ...
487+
@overload
488+
def f(x: MyEnumSubclass) -> MyEnumSubclass: ...
489+
```
490+
491+
```py
492+
from overloaded import MyEnumSubclass, ActualEnum, f
493+
494+
def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass):
495+
reveal_type(f(actual_enum)) # revealed: Both
496+
reveal_type(f(ActualEnum.A)) # revealed: OnlyA
497+
reveal_type(f(ActualEnum.B)) # revealed: OnlyB
498+
reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass
406499
```
407500

408501
### No matching overloads

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,111 @@ To do: <https://typing.python.org/en/latest/spec/enums.html#enum-definition>
570570

571571
## Exhaustiveness checking
572572

573-
To do
573+
## `if` statements
574+
575+
```py
576+
from enum import Enum
577+
from typing_extensions import assert_never
578+
579+
class Color(Enum):
580+
RED = 1
581+
GREEN = 2
582+
BLUE = 3
583+
584+
def color_name(color: Color) -> str:
585+
if color is Color.RED:
586+
return "Red"
587+
elif color is Color.GREEN:
588+
return "Green"
589+
elif color is Color.BLUE:
590+
return "Blue"
591+
else:
592+
assert_never(color)
593+
594+
# No `invalid-return-type` error here because the implicit `else` branch is detected as unreachable:
595+
def color_name_without_assertion(color: Color) -> str:
596+
if color is Color.RED:
597+
return "Red"
598+
elif color is Color.GREEN:
599+
return "Green"
600+
elif color is Color.BLUE:
601+
return "Blue"
602+
603+
def color_name_misses_one_variant(color: Color) -> str:
604+
if color is Color.RED:
605+
return "Red"
606+
elif color is Color.GREEN:
607+
return "Green"
608+
else:
609+
assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`"
610+
611+
class Singleton(Enum):
612+
VALUE = 1
613+
614+
def singleton_check(value: Singleton) -> str:
615+
if value is Singleton.VALUE:
616+
return "Singleton value"
617+
else:
618+
assert_never(value)
619+
```
620+
621+
## `match` statements
622+
623+
```toml
624+
[environment]
625+
python-version = "3.10"
626+
```
627+
628+
```py
629+
from enum import Enum
630+
from typing_extensions import assert_never
631+
632+
class Color(Enum):
633+
RED = 1
634+
GREEN = 2
635+
BLUE = 3
636+
637+
def color_name(color: Color) -> str:
638+
match color:
639+
case Color.RED:
640+
return "Red"
641+
case Color.GREEN:
642+
return "Green"
643+
case Color.BLUE:
644+
return "Blue"
645+
case _:
646+
assert_never(color)
647+
648+
# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488
649+
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`"
650+
def color_name_without_assertion(color: Color) -> str:
651+
match color:
652+
case Color.RED:
653+
return "Red"
654+
case Color.GREEN:
655+
return "Green"
656+
case Color.BLUE:
657+
return "Blue"
658+
659+
def color_name_misses_one_variant(color: Color) -> str:
660+
match color:
661+
case Color.RED:
662+
return "Red"
663+
case Color.GREEN:
664+
return "Green"
665+
case _:
666+
assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`"
667+
668+
class Singleton(Enum):
669+
VALUE = 1
670+
671+
def singleton_check(value: Singleton) -> str:
672+
match value:
673+
case Singleton.VALUE:
674+
return "Singleton value"
675+
case _:
676+
assert_never(value)
677+
```
574678

575679
## References
576680

crates/ty_python_semantic/resources/mdtest/intersection_types.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,65 @@ def f(
763763
reveal_type(j) # revealed: Unknown & Literal[""]
764764
```
765765

766+
## Simplifications involving enums and enum literals
767+
768+
```toml
769+
[environment]
770+
python-version = "3.12"
771+
```
772+
773+
```py
774+
from ty_extensions import Intersection, Not
775+
from typing import Literal
776+
from enum import Enum
777+
778+
class Color(Enum):
779+
RED = "red"
780+
GREEN = "green"
781+
BLUE = "blue"
782+
783+
type Red = Literal[Color.RED]
784+
type Green = Literal[Color.GREEN]
785+
type Blue = Literal[Color.BLUE]
786+
787+
def f(
788+
a: Intersection[Color, Red],
789+
b: Intersection[Color, Not[Red]],
790+
c: Intersection[Color, Not[Red | Green]],
791+
d: Intersection[Color, Not[Red | Green | Blue]],
792+
e: Intersection[Red, Not[Color]],
793+
f: Intersection[Red | Green, Not[Color]],
794+
g: Intersection[Not[Red], Color],
795+
h: Intersection[Red, Green],
796+
i: Intersection[Red | Green, Green | Blue],
797+
):
798+
reveal_type(a) # revealed: Literal[Color.RED]
799+
reveal_type(b) # revealed: Literal[Color.GREEN, Color.BLUE]
800+
reveal_type(c) # revealed: Literal[Color.BLUE]
801+
reveal_type(d) # revealed: Never
802+
reveal_type(e) # revealed: Never
803+
reveal_type(f) # revealed: Never
804+
reveal_type(g) # revealed: Literal[Color.GREEN, Color.BLUE]
805+
reveal_type(h) # revealed: Never
806+
reveal_type(i) # revealed: Literal[Color.GREEN]
807+
808+
class Single(Enum):
809+
VALUE = 0
810+
811+
def g(
812+
a: Intersection[Single, Literal[Single.VALUE]],
813+
b: Intersection[Single, Not[Literal[Single.VALUE]]],
814+
c: Intersection[Not[Literal[Single.VALUE]], Single],
815+
d: Intersection[Single, Not[Single]],
816+
e: Intersection[Single | int, Not[Single]],
817+
):
818+
reveal_type(a) # revealed: Single
819+
reveal_type(b) # revealed: Never
820+
reveal_type(c) # revealed: Never
821+
reveal_type(d) # revealed: Never
822+
reveal_type(e) # revealed: int
823+
```
824+
766825
## Addition of a type to an intersection with many non-disjoint types
767826

768827
This slightly strange-looking test is a regression test for a mistake that was nearly made in a PR:

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Narrowing for `!=` conditionals
1+
# Narrowing for `!=` and `==` conditionals
22

33
## `x != None`
44

@@ -22,6 +22,12 @@ def _(x: bool):
2222
reveal_type(x) # revealed: Literal[True]
2323
else:
2424
reveal_type(x) # revealed: Literal[False]
25+
26+
def _(x: bool):
27+
if x == False:
28+
reveal_type(x) # revealed: Literal[False]
29+
else:
30+
reveal_type(x) # revealed: Literal[True]
2531
```
2632

2733
### Enums
@@ -35,11 +41,31 @@ class Answer(Enum):
3541

3642
def _(answer: Answer):
3743
if answer != Answer.NO:
38-
# TODO: This should be simplified to `Literal[Answer.YES]`
39-
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
44+
reveal_type(answer) # revealed: Literal[Answer.YES]
45+
else:
46+
reveal_type(answer) # revealed: Literal[Answer.NO]
47+
48+
def _(answer: Answer):
49+
if answer == Answer.NO:
50+
reveal_type(answer) # revealed: Literal[Answer.NO]
51+
else:
52+
reveal_type(answer) # revealed: Literal[Answer.YES]
53+
54+
class Single(Enum):
55+
VALUE = 1
56+
57+
def _(x: Single | int):
58+
if x != Single.VALUE:
59+
reveal_type(x) # revealed: int
60+
else:
61+
# `int` is not eliminated here because there could be subclasses of `int` with custom `__eq__`/`__ne__` methods
62+
reveal_type(x) # revealed: Single | int
63+
64+
def _(x: Single | int):
65+
if x == Single.VALUE:
66+
reveal_type(x) # revealed: Single | int
4067
else:
41-
# TODO: This should be `Literal[Answer.NO]`
42-
reveal_type(answer) # revealed: Answer
68+
reveal_type(x) # revealed: int
4369
```
4470

4571
This narrowing behavior is only safe if the enum has no custom `__eq__`/`__ne__` method:

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,16 @@ def _(answer: Answer):
7878
if answer is Answer.NO:
7979
reveal_type(answer) # revealed: Literal[Answer.NO]
8080
else:
81-
# TODO: This should be `Literal[Answer.YES]`
82-
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
81+
reveal_type(answer) # revealed: Literal[Answer.YES]
82+
83+
class Single(Enum):
84+
VALUE = 1
85+
86+
def _(x: Single | int):
87+
if x is Single.VALUE:
88+
reveal_type(x) # revealed: Single
89+
else:
90+
reveal_type(x) # revealed: int
8391
```
8492

8593
## `is` for `EllipsisType` (Python 3.10+)

0 commit comments

Comments
 (0)