@@ -1789,6 +1789,30 @@ def Elem(*args):
17891789
17901790 Union [Elem , str ] # Nor should this
17911791
1792+ def test_union_of_literals (self ):
1793+ self .assertEqual (Union [Literal [1 ], Literal [2 ]].__args__ ,
1794+ (Literal [1 ], Literal [2 ]))
1795+ self .assertEqual (Union [Literal [1 ], Literal [1 ]],
1796+ Literal [1 ])
1797+
1798+ self .assertEqual (Union [Literal [False ], Literal [0 ]].__args__ ,
1799+ (Literal [False ], Literal [0 ]))
1800+ self .assertEqual (Union [Literal [True ], Literal [1 ]].__args__ ,
1801+ (Literal [True ], Literal [1 ]))
1802+
1803+ import enum
1804+ class Ints (enum .IntEnum ):
1805+ A = 0
1806+ B = 1
1807+
1808+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .B ]].__args__ ,
1809+ (Literal [Ints .A ], Literal [Ints .B ]))
1810+
1811+ self .assertEqual (Union [Literal [0 ], Literal [Ints .A ], Literal [False ]].__args__ ,
1812+ (Literal [0 ], Literal [Ints .A ], Literal [False ]))
1813+ self .assertEqual (Union [Literal [1 ], Literal [Ints .B ], Literal [True ]].__args__ ,
1814+ (Literal [1 ], Literal [Ints .B ], Literal [True ]))
1815+
17921816
17931817class TupleTests (BaseTestCase ):
17941818
@@ -2156,6 +2180,13 @@ def test_basics(self):
21562180 Literal [Literal [1 , 2 ], Literal [4 , 5 ]]
21572181 Literal [b"foo" , u"bar" ]
21582182
2183+ def test_enum (self ):
2184+ import enum
2185+ class My (enum .Enum ):
2186+ A = 'A'
2187+
2188+ self .assertEqual (Literal [My .A ].__args__ , (My .A ,))
2189+
21592190 def test_illegal_parameters_do_not_raise_runtime_errors (self ):
21602191 # Type checkers should reject these types, but we do not
21612192 # raise errors at runtime to maintain maximum flexibility.
@@ -2245,6 +2276,20 @@ def test_flatten(self):
22452276 self .assertEqual (l , Literal [1 , 2 , 3 ])
22462277 self .assertEqual (l .__args__ , (1 , 2 , 3 ))
22472278
2279+ def test_does_not_flatten_enum (self ):
2280+ import enum
2281+ class Ints (enum .IntEnum ):
2282+ A = 1
2283+ B = 2
2284+
2285+ l = Literal [
2286+ Literal [Ints .A ],
2287+ Literal [Ints .B ],
2288+ Literal [1 ],
2289+ Literal [2 ],
2290+ ]
2291+ self .assertEqual (l .__args__ , (Ints .A , Ints .B , 1 , 2 ))
2292+
22482293
22492294XK = TypeVar ('XK' , str , bytes )
22502295XV = TypeVar ('XV' )
0 commit comments