Skip to content

Commit df23202

Browse files
committed
allow CallableModelGenericType to be fully fledged callable model base
Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com>
1 parent 0775016 commit df23202

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

ccflow/callable.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,17 @@ class CallableModelGenericType(CallableModel, Generic[ContextType, ResultType]):
550550
context and result type will be validated.
551551
"""
552552

553+
_context_type: ClassVar[Type[ContextType]]
554+
_result_type: ClassVar[Type[ResultType]]
555+
556+
@property
557+
def context_type(self) -> Type[ContextType]:
558+
return self._context_type
559+
560+
@property
561+
def result_type(self) -> Type[ResultType]:
562+
return self._result_type
563+
553564
@model_validator(mode="wrap")
554565
def _validate_callable_model_generic_type(cls, m, handler, info):
555566
from ccflow.base import resolve_str
@@ -561,8 +572,26 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
561572
# Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
562573
if not isinstance(m, CallableModel):
563574
raise ValueError(f"{m} is not a CallableModel: {type(m)}")
564-
subtypes = cls.__pydantic_generic_metadata__["args"]
565-
if subtypes:
566-
TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type)
567-
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)
575+
576+
# Extract the generic types from the class definition
577+
generic_base = None
578+
for base in cls.__mro__[1:]:
579+
if issubclass(base, CallableModelGenericType):
580+
generic_base = base
581+
break
582+
583+
if generic_base and generic_base.__pydantic_generic_metadata__["args"]:
584+
# cls is subclass of generic_base which defines the generic types
585+
# so use these as the context and result types
586+
subtypes = generic_base.__pydantic_generic_metadata__["args"]
587+
if len(subtypes) != 2:
588+
raise ValueError("CallableModelGenericType must have exactly two generic type parameters: ContextType and ResultType")
589+
cls._context_type = subtypes[0]
590+
cls._result_type = subtypes[1]
591+
592+
else:
593+
subtypes = cls.__pydantic_generic_metadata__["args"]
594+
if subtypes:
595+
TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type)
596+
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)
568597
return m

ccflow/tests/test_callable.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,27 @@ def test_wrapper_reference(self):
317317
self.assertEqual(w.context_type, m.context_type)
318318
self.assertEqual(w.result_type, m.result_type)
319319

320-
def test_override_in_subclass(self):
321-
class MyCallable(CallableModelGenericType[NullContext, GenericResult]):
320+
def test_use_as_base_class(self):
321+
class MyCallable(CallableModelGenericType[NullContext, GenericResult[int]]):
322322
@Flow.call
323323
def __call__(self, context: NullContext) -> GenericResult[int]:
324324
return GenericResult[int](value=42)
325325

326326
m = MyCallable()
327327
self.assertEqual(m.context_type, NullContext)
328328
self.assertEqual(m.result_type, GenericResult[int])
329+
self.assertEqual(m(NullContext()).value, 42)
330+
331+
def test_use_as_base_class_no_call_annotations(self):
332+
class MyCallable(CallableModelGenericType[NullContext, GenericResult[int]]):
333+
@Flow.call
334+
def __call__(self, context):
335+
return GenericResult[int](value=42)
336+
337+
m = MyCallable()
338+
self.assertEqual(m.context_type, NullContext)
339+
self.assertEqual(m.result_type, GenericResult[int])
340+
self.assertEqual(m(NullContext()).value, 42)
329341

330342

331343
class TestCallableModelDeps(TestCase):

0 commit comments

Comments
 (0)