11import sys
2- from typing import Dict , Generic , List , Optional , Type , TypeVar , Union , get_args , get_origin
2+ from inspect import isclass
3+ from typing import Any , Callable , Dict , Generic , List , Optional , Type , TypeVar , Union , get_args , get_origin
34from unittest import TestCase
45
56import csp
@@ -24,7 +25,16 @@ def assertAnnotationsEqual(self, annotation1, annotation2):
2425 if origin1 is None :
2526 if isinstance (annotation1 , TypeVar ) and isinstance (annotation2 , TypeVar ):
2627 self .assertEqual (annotation1 .__name__ , annotation2 .__name__ )
27- elif issubclass (annotation1 , OutputBasket ) and issubclass (annotation2 , OutputBasket ):
28+ elif isinstance (annotation1 , list ) and isinstance (annotation2 , list ):
29+ self .assertEqual (len (annotation1 ), len (annotation2 ))
30+ for item1 , item2 in zip (annotation1 , annotation2 ):
31+ self .assertAnnotationsEqual (item1 , item2 )
32+ elif (
33+ isclass (annotation1 )
34+ and issubclass (annotation1 , OutputBasket )
35+ and isclass (annotation2 )
36+ and issubclass (annotation2 , OutputBasket )
37+ ):
2838 self .assertAnnotationsEqual (annotation1 .typ , annotation2 .typ )
2939 else :
3040 self .assertEqual (annotation1 , annotation2 )
@@ -56,6 +66,12 @@ def test_tvar_container(self):
5666 self .assertAnnotationsEqual (adjust_annotations (MyGeneric ["T" ]), MyGeneric [CspTypeVar [T ]])
5767 self .assertAnnotationsEqual (adjust_annotations (MyGeneric [T ]), MyGeneric [CspTypeVar [T ]])
5868
69+ def test_tvar_callable (self ):
70+ self .assertAnnotationsEqual (adjust_annotations (Callable [["T" ], Any ]), Callable [[CspTypeVar [T ]], Any ])
71+ self .assertAnnotationsEqual (
72+ adjust_annotations (Callable [["K" , "K" ], "T" ]), Callable [[CspTypeVar [K ], CspTypeVar [K ]], CspTypeVar [T ]]
73+ )
74+
5975 def test_tvar_ts_of_container (self ):
6076 self .assertAnnotationsEqual (adjust_annotations (ts ["T" ]), ts [CspTypeVarType [T ]])
6177 self .assertAnnotationsEqual (adjust_annotations (ts ["~T" ]), ts [CspTypeVarType [TypeVar ("~T" )]])
@@ -77,6 +93,15 @@ def test_tvar_ts_of_container(self):
7793 adjust_annotations (ts [Union [K , T ]]), ts [Union [CspTypeVarType [K ], CspTypeVarType [T ]]]
7894 )
7995
96+ def test_tvar_ts_of_callable (self ):
97+ self .assertAnnotationsEqual (
98+ adjust_annotations (ts [Callable [["T" ], Any ]]), ts [Callable [[CspTypeVarType [T ]], Any ]]
99+ )
100+ self .assertAnnotationsEqual (
101+ adjust_annotations (ts [Callable [["K" , "K" ], "T" ]]),
102+ ts [Callable [[CspTypeVarType [K ], CspTypeVarType [K ]], CspTypeVarType [T ]]],
103+ )
104+
80105 def test_tvar_container_of_ts (self ):
81106 self .assertAnnotationsEqual (adjust_annotations (List [ts ["T" ]]), List [ts [CspTypeVarType [T ]]])
82107 self .assertAnnotationsEqual (adjust_annotations (List [ts [T ]]), List [ts [CspTypeVarType [T ]]])
0 commit comments