@@ -1784,6 +1784,30 @@ def Elem(*args):
17841784
17851785 Union [Elem , str ] # Nor should this
17861786
1787+ def test_union_of_literals (self ):
1788+ self .assertEqual (Union [Literal [1 ], Literal [2 ]].__args__ ,
1789+ (Literal [1 ], Literal [2 ]))
1790+ self .assertEqual (Union [Literal [1 ], Literal [1 ]],
1791+ Literal [1 ])
1792+
1793+ self .assertEqual (Union [Literal [False ], Literal [0 ]].__args__ ,
1794+ (Literal [False ], Literal [0 ]))
1795+ self .assertEqual (Union [Literal [True ], Literal [1 ]].__args__ ,
1796+ (Literal [True ], Literal [1 ]))
1797+
1798+ import enum
1799+ class Ints (enum .IntEnum ):
1800+ A = 0
1801+ B = 1
1802+
1803+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .B ]].__args__ ,
1804+ (Literal [Ints .A ], Literal [Ints .B ]))
1805+
1806+ self .assertEqual (Union [Literal [0 ], Literal [Ints .A ], Literal [False ]].__args__ ,
1807+ (Literal [0 ], Literal [Ints .A ], Literal [False ]))
1808+ self .assertEqual (Union [Literal [1 ], Literal [Ints .B ], Literal [True ]].__args__ ,
1809+ (Literal [1 ], Literal [Ints .B ], Literal [True ]))
1810+
17871811
17881812class TupleTests (BaseTestCase ):
17891813
@@ -2151,6 +2175,13 @@ def test_basics(self):
21512175 Literal [Literal [1 , 2 ], Literal [4 , 5 ]]
21522176 Literal [b"foo" , u"bar" ]
21532177
2178+ def test_enum (self ):
2179+ import enum
2180+ class My (enum .Enum ):
2181+ A = 'A'
2182+
2183+ self .assertEqual (Literal [My .A ].__args__ , (My .A ,))
2184+
21542185 def test_illegal_parameters_do_not_raise_runtime_errors (self ):
21552186 # Type checkers should reject these types, but we do not
21562187 # raise errors at runtime to maintain maximum flexibility.
@@ -2240,6 +2271,20 @@ def test_flatten(self):
22402271 self .assertEqual (l , Literal [1 , 2 , 3 ])
22412272 self .assertEqual (l .__args__ , (1 , 2 , 3 ))
22422273
2274+ def test_does_not_flatten_enum (self ):
2275+ import enum
2276+ class Ints (enum .IntEnum ):
2277+ A = 1
2278+ B = 2
2279+
2280+ l = Literal [
2281+ Literal [Ints .A ],
2282+ Literal [Ints .B ],
2283+ Literal [1 ],
2284+ Literal [2 ],
2285+ ]
2286+ self .assertEqual (l .__args__ , (Ints .A , Ints .B , 1 , 2 ))
2287+
22432288
22442289XK = TypeVar ('XK' , str , bytes )
22452290XV = TypeVar ('XV' )
0 commit comments