Skip to content

Commit b09cc84

Browse files
committed
Fix pydantic annotation handling for Callable
Signed-off-by: Pascal Tomecek <pascal.tomecek@cubistsystematic.com>
1 parent a436949 commit b09cc84

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

csp/impl/types/pydantic_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def adjust_annotations(
129129
annotation.typ, top_level=False, in_ts=False, make_optional=False, forced_tvars=forced_tvars
130130
)
131131
)
132+
elif isinstance(annotation, list): # Handle list, i.e. in Callable[[Annotation], Any]
133+
return [
134+
adjust_annotations(a, top_level=False, in_ts=in_ts, make_optional=False, forced_tvars=forced_tvars)
135+
for a in annotation
136+
]
132137

133138
if type(annotation) is ForwardRef:
134139
if in_ts:

csp/tests/impl/types/test_pydantic_types.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
2+
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
33
from unittest import TestCase
44

55
import csp
@@ -24,6 +24,10 @@ def assertAnnotationsEqual(self, annotation1, annotation2):
2424
if origin1 is None:
2525
if isinstance(annotation1, TypeVar) and isinstance(annotation2, TypeVar):
2626
self.assertEqual(annotation1.__name__, annotation2.__name__)
27+
elif isinstance(annotation1, list) and isinstance(annotation2, list):
28+
self.assertEqual(len(annotation1), len(annotation2))
29+
for item1, item2 in zip(annotation1, annotation2):
30+
self.assertAnnotationsEqual(item1, item2)
2731
elif issubclass(annotation1, OutputBasket) and issubclass(annotation2, OutputBasket):
2832
self.assertAnnotationsEqual(annotation1.typ, annotation2.typ)
2933
else:
@@ -56,6 +60,12 @@ def test_tvar_container(self):
5660
self.assertAnnotationsEqual(adjust_annotations(MyGeneric["T"]), MyGeneric[CspTypeVar[T]])
5761
self.assertAnnotationsEqual(adjust_annotations(MyGeneric[T]), MyGeneric[CspTypeVar[T]])
5862

63+
def test_tvar_callable(self):
64+
self.assertAnnotationsEqual(adjust_annotations(Callable[["T"], Any]), Callable[[CspTypeVar[T]], Any])
65+
self.assertAnnotationsEqual(
66+
adjust_annotations(Callable[["K", "K"], "T"]), Callable[[CspTypeVar[K], CspTypeVar[K]], CspTypeVar[T]]
67+
)
68+
5969
def test_tvar_ts_of_container(self):
6070
self.assertAnnotationsEqual(adjust_annotations(ts["T"]), ts[CspTypeVarType[T]])
6171
self.assertAnnotationsEqual(adjust_annotations(ts["~T"]), ts[CspTypeVarType[TypeVar("~T")]])
@@ -77,6 +87,15 @@ def test_tvar_ts_of_container(self):
7787
adjust_annotations(ts[Union[K, T]]), ts[Union[CspTypeVarType[K], CspTypeVarType[T]]]
7888
)
7989

90+
def test_tvar_ts_of_callable(self):
91+
self.assertAnnotationsEqual(
92+
adjust_annotations(ts[Callable[["T"], Any]]), ts[Callable[[CspTypeVarType[T]], Any]]
93+
)
94+
self.assertAnnotationsEqual(
95+
adjust_annotations(ts[Callable[["K", "K"], "T"]]),
96+
ts[Callable[[CspTypeVarType[K], CspTypeVarType[K]], CspTypeVarType[T]]],
97+
)
98+
8099
def test_tvar_container_of_ts(self):
81100
self.assertAnnotationsEqual(adjust_annotations(List[ts["T"]]), List[ts[CspTypeVarType[T]]])
82101
self.assertAnnotationsEqual(adjust_annotations(List[ts[T]]), List[ts[CspTypeVarType[T]]])

0 commit comments

Comments
 (0)