Skip to content

Commit fcb1716

Browse files
authored
Merge pull request #152 from Point72/tkp/un
Allow for union types in call signature return if separate concerete property
2 parents ca4b202 + cb9d44a commit fcb1716

File tree

2 files changed

+154
-28
lines changed

2 files changed

+154
-28
lines changed

ccflow/callable.py

Lines changed: 89 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,27 @@ def _check_signature(self):
119119

120120
if hasattr(self, "result_type"):
121121
type_call_return = _cached_signature(self.__class__.__call__).return_annotation
122-
if (
123-
not isinstance(type_call_return, TypeVar)
124-
and type_call_return is not Signature.empty
125-
and (not isclass(type_call_return) or not issubclass(type_call_return, self.result_type))
126-
and (not isclass(self.result_type) or not issubclass(self.result_type, type_call_return))
127-
):
122+
123+
# If union, check all types
124+
if get_origin(type_call_return) is Union and get_args(type_call_return):
125+
types_call_return = [t for t in get_args(type_call_return) if t is not type(None)]
126+
else:
127+
types_call_return = [type_call_return]
128+
129+
all_bad = True
130+
for type_call_return in types_call_return:
131+
if (
132+
not isinstance(type_call_return, TypeVar)
133+
and type_call_return is not Signature.empty
134+
and (not isclass(type_call_return) or not issubclass(type_call_return, self.result_type))
135+
and (not isclass(self.result_type) or not issubclass(self.result_type, type_call_return))
136+
):
137+
# Don't invert logic so that we match context above
138+
pass
139+
else:
140+
all_bad = False
141+
142+
if all_bad:
128143
err_msg_type_mismatch = f"The result_type {self.result_type} must match the return type of __call__ {type_call_return}"
129144
raise ValueError(err_msg_type_mismatch)
130145

@@ -251,7 +266,9 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
251266
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
252267
):
253268
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")
254-
if not isclass(model.result_type) or not issubclass(model.result_type, ResultBase):
269+
if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not (
270+
get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type))
271+
):
255272
raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase")
256273
if self._deps and fn.__name__ != "__deps__":
257274
raise ValueError("Can only apply Flow.deps decorator to __deps__")
@@ -457,7 +474,10 @@ def __call__(self) -> ResultType:
457474
elif hasattr(result, "_lazy_is_delayed"):
458475
object.__setattr__(result, "_lazy_validation_requested", True)
459476
elif hasattr(self.model, "result_type"):
460-
result = self.model.result_type.model_validate(result)
477+
result_type = self.model.result_type
478+
if not isclass(result_type) or not issubclass(result_type, ResultBase):
479+
raise TypeError(f"Model result_type {result_type} is not a subclass of ResultBase")
480+
result = result_type.model_validate(result)
461481

462482
return result
463483
else:
@@ -530,16 +550,20 @@ def context_type(self) -> Type[ContextType]:
530550
if typ is Signature.empty:
531551
raise TypeError("Must either define a type annotation for context on __call__ or implement 'context_type'")
532552

553+
self._check_context_type(typ)
554+
return typ
555+
556+
@staticmethod
557+
def _check_context_type(typ):
533558
# If optional type, extract inner type
534559
if get_origin(typ) is Optional or (get_origin(typ) is Union and type(None) in get_args(typ)):
535-
typ_to_check = [t for t in get_args(typ) if t is not type(None)][0]
560+
type_to_check = [t for t in get_args(typ) if t is not type(None)][0]
536561
else:
537-
typ_to_check = typ
562+
type_to_check = typ
538563

539564
# Ensure subclass of ContextBase
540-
if not isclass(typ_to_check) or not issubclass(typ_to_check, ContextBase):
541-
raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {typ_to_check}.")
542-
return typ
565+
if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase):
566+
raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.")
543567

544568
@property
545569
def result_type(self) -> Type[ResultType]:
@@ -551,9 +575,21 @@ def result_type(self) -> Type[ResultType]:
551575
typ = _cached_signature(self.__class__.__call__).return_annotation
552576
if typ is Signature.empty:
553577
raise TypeError("Must either define a return type annotation on __call__ or implement 'result_type'")
578+
579+
self._check_result_type(typ)
580+
return typ
581+
582+
@staticmethod
583+
def _check_result_type(typ):
584+
# If union type, extract inner type
585+
if get_origin(typ) is Union:
586+
raise TypeError(
587+
"Model __call__ signature result type cannot be a Union type without a concrete property. Please define a property 'result_type' on the model."
588+
)
589+
590+
# Ensure subclass of ResultBase
554591
if not isclass(typ) or not issubclass(typ, ResultBase):
555592
raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.")
556-
return typ
557593

558594
@Flow.deps
559595
def __deps__(
@@ -615,28 +651,45 @@ def _determine_context_result(cls):
615651
if not hasattr(cls, "_context_type") or not hasattr(cls, "_result_type"):
616652
new_context_type = None
617653
new_result_type = None
654+
618655
for base in cls.__mro__:
619656
if issubclass(base, CallableModelGenericType):
620657
# Found the generic base class, it should
621658
# have either generic parameters or context/result
622659
if new_context_type is None and hasattr(base, "_context_type") and issubclass(base._context_type, ContextBase):
623660
new_context_type = base._context_type
624-
if new_result_type is None and hasattr(base, "_result_type") and issubclass(base._result_type, ResultBase):
661+
if (
662+
new_result_type is None
663+
and hasattr(base, "_result_type")
664+
and (
665+
issubclass(base._result_type, ResultBase)
666+
or (
667+
get_origin(base._result_type) is Union
668+
and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(base._result_type))
669+
)
670+
)
671+
):
625672
new_result_type = base._result_type
626673
if base.__pydantic_generic_metadata__["args"]:
627674
if len(base.__pydantic_generic_metadata__["args"]) >= 2:
628675
# Assume order is ContextType, ResultType
629676
arg0, arg1 = base.__pydantic_generic_metadata__["args"][:2]
630677
if new_context_type is None and isinstance(arg0, type) and issubclass(arg0, ContextBase):
631678
new_context_type = arg0
632-
if new_result_type is None and isinstance(arg1, type) and issubclass(arg1, ResultBase):
679+
if new_result_type is None and (
680+
(isinstance(arg1, type) and issubclass(arg1, ResultBase))
681+
or (get_origin(arg1) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(arg1)))
682+
):
633683
# NOTE: ContextBase inherits from ResultBase, so order matters here!
634684
new_result_type = arg1
635685
else:
636686
for arg in base.__pydantic_generic_metadata__["args"]:
637687
if new_context_type is None and isinstance(arg, type) and issubclass(arg, ContextBase):
638688
new_context_type = arg
639-
elif new_result_type is None and isinstance(arg, type) and issubclass(arg, ResultBase):
689+
elif new_result_type is None and (
690+
(isinstance(arg, type) and issubclass(arg, ResultBase))
691+
or (get_origin(arg) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(arg)))
692+
):
640693
# NOTE: ContextBase inherits from ResultBase, so order matters here!
641694
new_result_type = arg
642695
if new_context_type and new_result_type:
@@ -666,11 +719,25 @@ def _determine_context_result(cls):
666719
if new_result_type is not None:
667720
# Validate that the model's result_type match
668721
annotation_result_type = _cached_signature(cls.__call__).return_annotation
669-
if (
670-
annotation_result_type is not Signature.empty
671-
and not isinstance(annotation_result_type, TypeVar)
672-
and not issubclass(annotation_result_type, new_result_type)
673-
):
722+
if annotation_result_type is Signature.empty:
723+
...
724+
elif isinstance(annotation_result_type, TypeVar):
725+
...
726+
elif get_origin(annotation_result_type) is Union and get_origin(new_result_type) is Union:
727+
raise TypeError(
728+
f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received {annotation_result_type}"
729+
)
730+
elif get_origin(annotation_result_type) is Union:
731+
if not any(issubclass(new_result_type, union_type) for union_type in get_args(annotation_result_type)):
732+
raise TypeError(
733+
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
734+
)
735+
elif get_origin(new_result_type) is Union:
736+
if not any(issubclass(annotation_result_type, union_type) for union_type in get_args(new_result_type)):
737+
raise TypeError(
738+
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
739+
)
740+
elif not issubclass(annotation_result_type, new_result_type):
674741
raise TypeError(
675742
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
676743
)

ccflow/tests/test_callable.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pickle import dumps as pdumps, loads as ploads
2-
from typing import Generic, List, Optional, Tuple, Type, TypeVar
2+
from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union
33
from unittest import TestCase
44

55
import ray
@@ -268,6 +268,46 @@ def __call__(self, context: NullContext) -> GenericResult[int]:
268268
return GenericResult[int](value=42)
269269

270270

271+
class AResult(ResultBase):
272+
a: int
273+
274+
275+
class BResult(ResultBase):
276+
b: str
277+
278+
279+
class UnionReturn(CallableModel):
280+
@property
281+
def result_type(self) -> Type[ResultType]:
282+
return AResult
283+
284+
@Flow.call
285+
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
286+
# Return one branch of the Union
287+
return AResult(a=1)
288+
289+
290+
class BadModelUnionReturnNoProperty(CallableModel):
291+
@Flow.call
292+
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
293+
# Return one branch of the Union
294+
return AResult(a=1)
295+
296+
297+
class UnionReturnGeneric(CallableModelGenericType[NullContext, AResult]):
298+
@Flow.call
299+
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
300+
# Return one branch of the Union
301+
return AResult(a=1)
302+
303+
304+
class BadModelUnionReturnGeneric(CallableModelGenericType[NullContext, Union[AResult, BResult]]):
305+
@Flow.call
306+
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
307+
# Return one branch of the Union
308+
return AResult(a=1)
309+
310+
271311
class MyGenericContext(ContextBase, Generic[TContext]):
272312
value: TContext
273313

@@ -419,14 +459,11 @@ def test_types(self):
419459
error = "The context_type <class 'ccflow.context.NullContext'> must match the type of the context accepted by __call__ <class 'ccflow.tests.test_callable.MyContext'>"
420460
self.assertRaisesRegex(ValueError, error, BadModelMismatchedContextAndCall)
421461

422-
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.context.NullContext'> defined by CallableModelGenericType"
423-
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedContextAndCall)
424-
425462
error = "The result_type <class 'ccflow.result.generic.GenericResult'> must match the return type of __call__ <class 'ccflow.tests.test_callable.MyResult'>"
426463
self.assertRaisesRegex(ValueError, error, BadModelMismatchedResultAndCall)
427464

428-
error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"
429-
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedResultAndCall)
465+
error = "Model __call__ signature result type cannot be a Union type without a concrete property. Please define a property 'result_type' on the model."
466+
self.assertRaisesRegex(TypeError, error, BadModelUnionReturnNoProperty)
430467

431468
def test_identity(self):
432469
# Make sure that an "identity" mapping works
@@ -441,6 +478,12 @@ def test_context_call_match_enforcement_generic_base(self):
441478
# This should not raise
442479
_ = ModelMixedGenericsEnforceContextMatch(model=IdentityCallable())
443480

481+
def test_union_return(self):
482+
m = UnionReturn()
483+
result = m(NullContext())
484+
self.assertIsInstance(result, AResult)
485+
self.assertEqual(result.a, 1)
486+
444487

445488
class TestWrapperModel(TestCase):
446489
def test_wrapper(self):
@@ -594,6 +637,22 @@ def __call__(self, context: NullContext) -> GenericResult[float]:
594637
with self.assertRaises(TypeError):
595638
MyCallable()
596639

640+
def test_types_generic(self):
641+
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.context.NullContext'> defined by CallableModelGenericType"
642+
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedContextAndCall)
643+
644+
error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"
645+
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedResultAndCall)
646+
647+
error = "Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`"
648+
self.assertRaisesRegex(TypeError, error, BadModelUnionReturnGeneric)
649+
650+
def test_union_return_generic(self):
651+
m = UnionReturnGeneric()
652+
result = m(NullContext())
653+
self.assertIsInstance(result, AResult)
654+
self.assertEqual(result.a, 1)
655+
597656

598657
class TestCallableModelDeps(TestCase):
599658
def test_basic(self):

0 commit comments

Comments
 (0)