Skip to content

Commit 8f9ded8

Browse files
authored
Fix pydantic annotation handling for Callable (#455)
Signed-off-by: Pascal Tomecek <pascal.tomecek@cubistsystematic.com>
1 parent c938b85 commit 8f9ded8

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

csp/impl/types/pydantic_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def adjust_annotations(
129129
annotation.typ, top_level=False, in_ts=False, make_optional=False, forced_tvars=forced_tvars
130130
)
131131
)
132+
elif isinstance(annotation, list): # Handle list, i.e. in Callable[[Annotation], Any]
133+
return [
134+
adjust_annotations(a, top_level=False, in_ts=in_ts, make_optional=False, forced_tvars=forced_tvars)
135+
for a in annotation
136+
]
132137

133138
if type(annotation) is ForwardRef:
134139
if in_ts:

csp/tests/impl/types/test_pydantic_types.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
2-
from typing import Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
2+
from inspect import isclass
3+
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
34
from unittest import TestCase
45

56
import csp
@@ -24,7 +25,16 @@ def assertAnnotationsEqual(self, annotation1, annotation2):
2425
if origin1 is None:
2526
if isinstance(annotation1, TypeVar) and isinstance(annotation2, TypeVar):
2627
self.assertEqual(annotation1.__name__, annotation2.__name__)
27-
elif issubclass(annotation1, OutputBasket) and issubclass(annotation2, OutputBasket):
28+
elif isinstance(annotation1, list) and isinstance(annotation2, list):
29+
self.assertEqual(len(annotation1), len(annotation2))
30+
for item1, item2 in zip(annotation1, annotation2):
31+
self.assertAnnotationsEqual(item1, item2)
32+
elif (
33+
isclass(annotation1)
34+
and issubclass(annotation1, OutputBasket)
35+
and isclass(annotation2)
36+
and issubclass(annotation2, OutputBasket)
37+
):
2838
self.assertAnnotationsEqual(annotation1.typ, annotation2.typ)
2939
else:
3040
self.assertEqual(annotation1, annotation2)
@@ -56,6 +66,12 @@ def test_tvar_container(self):
5666
self.assertAnnotationsEqual(adjust_annotations(MyGeneric["T"]), MyGeneric[CspTypeVar[T]])
5767
self.assertAnnotationsEqual(adjust_annotations(MyGeneric[T]), MyGeneric[CspTypeVar[T]])
5868

69+
def test_tvar_callable(self):
70+
self.assertAnnotationsEqual(adjust_annotations(Callable[["T"], Any]), Callable[[CspTypeVar[T]], Any])
71+
self.assertAnnotationsEqual(
72+
adjust_annotations(Callable[["K", "K"], "T"]), Callable[[CspTypeVar[K], CspTypeVar[K]], CspTypeVar[T]]
73+
)
74+
5975
def test_tvar_ts_of_container(self):
6076
self.assertAnnotationsEqual(adjust_annotations(ts["T"]), ts[CspTypeVarType[T]])
6177
self.assertAnnotationsEqual(adjust_annotations(ts["~T"]), ts[CspTypeVarType[TypeVar("~T")]])
@@ -77,6 +93,15 @@ def test_tvar_ts_of_container(self):
7793
adjust_annotations(ts[Union[K, T]]), ts[Union[CspTypeVarType[K], CspTypeVarType[T]]]
7894
)
7995

96+
def test_tvar_ts_of_callable(self):
97+
self.assertAnnotationsEqual(
98+
adjust_annotations(ts[Callable[["T"], Any]]), ts[Callable[[CspTypeVarType[T]], Any]]
99+
)
100+
self.assertAnnotationsEqual(
101+
adjust_annotations(ts[Callable[["K", "K"], "T"]]),
102+
ts[Callable[[CspTypeVarType[K], CspTypeVarType[K]], CspTypeVarType[T]]],
103+
)
104+
80105
def test_tvar_container_of_ts(self):
81106
self.assertAnnotationsEqual(adjust_annotations(List[ts["T"]]), List[ts[CspTypeVarType[T]]])
82107
self.assertAnnotationsEqual(adjust_annotations(List[ts[T]]), List[ts[CspTypeVarType[T]]])

0 commit comments

Comments
 (0)