@@ -10758,6 +10758,69 @@ def test_eq(self):
1075810758 with self .assertWarns (DeprecationWarning ):
1075910759 self .assertNotEqual (int , typing ._UnionGenericAlias )
1076010760
10761+ class MyType :
10762+ pass
10763+
10764+ class test_generic_alias_handling (BaseTestCase ):
10765+ def test_forward_ref (self ):
10766+ fwd_ref = ForwardRef ('MyType' )
10767+ result = _eval_type (fwd_ref , globals (), locals ())
10768+ self .assertIs (result , MyType , f"Expected MyType, got { result } " )
10769+
10770+ def test_generic_alias (self ):
10771+ fwd_ref = ForwardRef ('MyType' )
10772+ generic_list = List [fwd_ref ]
10773+ result = _eval_type (generic_list , globals (), locals ())
10774+ self .assertEqual (result , List [MyType ], f"Expected List[MyType], got { result } " )
10775+
10776+ def test_union (self ):
10777+ fwd_ref_1 = ForwardRef ('MyType' )
10778+ fwd_ref_2 = ForwardRef ('int' )
10779+ union_type = Union [fwd_ref_1 , fwd_ref_2 ]
10780+
10781+ result = _eval_type (union_type , globals (), locals ())
10782+ self .assertEqual (result , Union [MyType , int ], f"Expected Union[MyType, int], got { result } " )
10783+
10784+ def test_recursive_forward_ref (self ):
10785+ recursive_ref = ForwardRef ('RecursiveType' )
10786+ globals ()['RecursiveType' ] = recursive_ref
10787+
10788+ recursive_type = Dict [str , List [recursive_ref ]]
10789+
10790+ result = _eval_type (recursive_type , globals (), locals (), recursive_guard = {recursive_ref })
10791+
10792+ self .assertEqual (result , Dict [str , List [recursive_ref ]], f"Expected Dict[str, List[RecursiveType]], got { result } " )
10793+
10794+ def test_callable_unpacking (self ):
10795+ fwd_ref = ForwardRef ('MyType' )
10796+ callable_type = Callable [[fwd_ref , int ], str ]
10797+ result = _eval_type (callable_type , globals (), locals ())
10798+
10799+ self .assertEqual (result , Callable [[MyType , int ], str ], f"Expected Callable[[MyType, int], str], got { result } " )
10800+
10801+ def test_unpacked_generic (self ):
10802+ fwd_ref = ForwardRef ('MyType' )
10803+ generic_type = Tuple [fwd_ref , int ]
10804+
10805+ result = _eval_type (generic_type , globals (), locals ())
10806+ self .assertEqual (result , Tuple [MyType , int ], f"Expected Tuple[MyType, int], got { result } " )
10807+
10808+ def test_preservation_of_type (self ):
10809+ fwd_ref_1 = ForwardRef ('MyType' )
10810+ fwd_ref_2 = ForwardRef ('int' )
10811+ complex_type = Dict [str , Union [fwd_ref_1 , fwd_ref_2 ]]
10812+
10813+ result = _eval_type (complex_type , globals (), locals ())
10814+ self .assertEqual (result , Dict [str , Union [MyType , int ]], f"Expected Dict[str, Union[MyType, int]], got { result } " )
10815+
10816+ def test_callable_unflattening (self ):
10817+ callable_type = Callable [[int , str ], bool ]
10818+ result = _eval_type (callable_type , globals (), locals (), type_params = ())
10819+ self .assertEqual (result , Callable [[int , str ], bool ], f"Expected Callable[[int, str], bool], got { result } " )
10820+
10821+ callable_type_packed = Callable [[int , str ], bool ] # Correct format for callable
10822+ result = _eval_type (callable_type_packed , globals (), locals (), type_params = ())
10823+ self .assertEqual (result , Callable [[int , str ], bool ], f"Expected Callable[[int, str], bool], got { result } " )
1076110824
1076210825def load_tests (loader , tests , pattern ):
1076310826 import doctest
0 commit comments