Skip to content

Commit 06076df

Browse files
authored
Merge pull request #136 from Point72/tkp/sc
Allow inheritance from callablegeneric type
2 parents f144bf0 + 0775016 commit 06076df

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

ccflow/callable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
556556

557557
if isinstance(m, str):
558558
m = resolve_str(m)
559+
if isinstance(m, dict):
560+
m = handler(m)
559561
# Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
560562
if not isinstance(m, CallableModel):
561563
raise ValueError(f"{m} is not a CallableModel: {type(m)}")

ccflow/tests/test_callable.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ContextBase,
1010
ContextType,
1111
Flow,
12+
GenericResult,
1213
GraphDepList,
1314
MetaData,
1415
ModelRegistry,
@@ -316,6 +317,16 @@ def test_wrapper_reference(self):
316317
self.assertEqual(w.context_type, m.context_type)
317318
self.assertEqual(w.result_type, m.result_type)
318319

320+
def test_override_in_subclass(self):
321+
class MyCallable(CallableModelGenericType[NullContext, GenericResult]):
322+
@Flow.call
323+
def __call__(self, context: NullContext) -> GenericResult[int]:
324+
return GenericResult[int](value=42)
325+
326+
m = MyCallable()
327+
self.assertEqual(m.context_type, NullContext)
328+
self.assertEqual(m.result_type, GenericResult[int])
329+
319330

320331
class TestCallableModelDeps(TestCase):
321332
def test_basic(self):

0 commit comments

Comments
 (0)