@@ -119,12 +119,27 @@ def _check_signature(self):
119119
120120 if hasattr (self , "result_type" ):
121121 type_call_return = _cached_signature (self .__class__ .__call__ ).return_annotation
122- if (
123- not isinstance (type_call_return , TypeVar )
124- and type_call_return is not Signature .empty
125- and (not isclass (type_call_return ) or not issubclass (type_call_return , self .result_type ))
126- and (not isclass (self .result_type ) or not issubclass (self .result_type , type_call_return ))
127- ):
122+
123+ # If union, check all types
124+ if get_origin (type_call_return ) is Union and get_args (type_call_return ):
125+ types_call_return = [t for t in get_args (type_call_return ) if t is not type (None )]
126+ else :
127+ types_call_return = [type_call_return ]
128+
129+ all_bad = True
130+ for type_call_return in types_call_return :
131+ if (
132+ not isinstance (type_call_return , TypeVar )
133+ and type_call_return is not Signature .empty
134+ and (not isclass (type_call_return ) or not issubclass (type_call_return , self .result_type ))
135+ and (not isclass (self .result_type ) or not issubclass (self .result_type , type_call_return ))
136+ ):
137+ # Don't invert logic so that we match context above
138+ pass
139+ else :
140+ all_bad = False
141+
142+ if all_bad :
128143 err_msg_type_mismatch = f"The result_type { self .result_type } must match the return type of __call__ { type_call_return } "
129144 raise ValueError (err_msg_type_mismatch )
130145
@@ -251,7 +266,9 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
251266 get_origin (model .context_type ) is Union and type (None ) in get_args (model .context_type )
252267 ):
253268 raise TypeError (f"Context type { model .context_type } must be a subclass of ContextBase" )
254- if not isclass (model .result_type ) or not issubclass (model .result_type , ResultBase ):
269+ if (not isclass (model .result_type ) or not issubclass (model .result_type , ResultBase )) and not (
270+ get_origin (model .result_type ) is Union and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (model .result_type ))
271+ ):
255272 raise TypeError (f"Result type { model .result_type } must be a subclass of ResultBase" )
256273 if self ._deps and fn .__name__ != "__deps__" :
257274 raise ValueError ("Can only apply Flow.deps decorator to __deps__" )
@@ -457,7 +474,10 @@ def __call__(self) -> ResultType:
457474 elif hasattr (result , "_lazy_is_delayed" ):
458475 object .__setattr__ (result , "_lazy_validation_requested" , True )
459476 elif hasattr (self .model , "result_type" ):
460- result = self .model .result_type .model_validate (result )
477+ result_type = self .model .result_type
478+ if not isclass (result_type ) or not issubclass (result_type , ResultBase ):
479+ raise TypeError (f"Model result_type { result_type } is not a subclass of ResultBase" )
480+ result = result_type .model_validate (result )
461481
462482 return result
463483 else :
@@ -530,16 +550,20 @@ def context_type(self) -> Type[ContextType]:
530550 if typ is Signature .empty :
531551 raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
532552
553+ self ._check_context_type (typ )
554+ return typ
555+
556+ @staticmethod
557+ def _check_context_type (typ ):
533558 # If optional type, extract inner type
534559 if get_origin (typ ) is Optional or (get_origin (typ ) is Union and type (None ) in get_args (typ )):
535- typ_to_check = [t for t in get_args (typ ) if t is not type (None )][0 ]
560+ type_to_check = [t for t in get_args (typ ) if t is not type (None )][0 ]
536561 else :
537- typ_to_check = typ
562+ type_to_check = typ
538563
539564 # Ensure subclass of ContextBase
540- if not isclass (typ_to_check ) or not issubclass (typ_to_check , ContextBase ):
541- raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { typ_to_check } ." )
542- return typ
565+ if not isclass (type_to_check ) or not issubclass (type_to_check , ContextBase ):
566+ raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { type_to_check } ." )
543567
544568 @property
545569 def result_type (self ) -> Type [ResultType ]:
@@ -551,9 +575,21 @@ def result_type(self) -> Type[ResultType]:
551575 typ = _cached_signature (self .__class__ .__call__ ).return_annotation
552576 if typ is Signature .empty :
553577 raise TypeError ("Must either define a return type annotation on __call__ or implement 'result_type'" )
578+
579+ self ._check_result_type (typ )
580+ return typ
581+
582+ @staticmethod
583+ def _check_result_type (typ ):
584+ # If union type, extract inner type
585+ if get_origin (typ ) is Union :
586+ raise TypeError (
587+ "Model __call__ signature result type cannot be a Union type without a concrete property. Please define a property 'result_type' on the model."
588+ )
589+
590+ # Ensure subclass of ResultBase
554591 if not isclass (typ ) or not issubclass (typ , ResultBase ):
555592 raise TypeError (f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received { typ } ." )
556- return typ
557593
558594 @Flow .deps
559595 def __deps__ (
@@ -615,28 +651,45 @@ def _determine_context_result(cls):
615651 if not hasattr (cls , "_context_type" ) or not hasattr (cls , "_result_type" ):
616652 new_context_type = None
617653 new_result_type = None
654+
618655 for base in cls .__mro__ :
619656 if issubclass (base , CallableModelGenericType ):
620657 # Found the generic base class, it should
621658 # have either generic parameters or context/result
622659 if new_context_type is None and hasattr (base , "_context_type" ) and issubclass (base ._context_type , ContextBase ):
623660 new_context_type = base ._context_type
624- if new_result_type is None and hasattr (base , "_result_type" ) and issubclass (base ._result_type , ResultBase ):
661+ if (
662+ new_result_type is None
663+ and hasattr (base , "_result_type" )
664+ and (
665+ issubclass (base ._result_type , ResultBase )
666+ or (
667+ get_origin (base ._result_type ) is Union
668+ and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (base ._result_type ))
669+ )
670+ )
671+ ):
625672 new_result_type = base ._result_type
626673 if base .__pydantic_generic_metadata__ ["args" ]:
627674 if len (base .__pydantic_generic_metadata__ ["args" ]) >= 2 :
628675 # Assume order is ContextType, ResultType
629676 arg0 , arg1 = base .__pydantic_generic_metadata__ ["args" ][:2 ]
630677 if new_context_type is None and isinstance (arg0 , type ) and issubclass (arg0 , ContextBase ):
631678 new_context_type = arg0
632- if new_result_type is None and isinstance (arg1 , type ) and issubclass (arg1 , ResultBase ):
679+ if new_result_type is None and (
680+ (isinstance (arg1 , type ) and issubclass (arg1 , ResultBase ))
681+ or (get_origin (arg1 ) is Union and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (arg1 )))
682+ ):
633683 # NOTE: ContextBase inherits from ResultBase, so order matters here!
634684 new_result_type = arg1
635685 else :
636686 for arg in base .__pydantic_generic_metadata__ ["args" ]:
637687 if new_context_type is None and isinstance (arg , type ) and issubclass (arg , ContextBase ):
638688 new_context_type = arg
639- elif new_result_type is None and isinstance (arg , type ) and issubclass (arg , ResultBase ):
689+ elif new_result_type is None and (
690+ (isinstance (arg , type ) and issubclass (arg , ResultBase ))
691+ or (get_origin (arg ) is Union and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (arg )))
692+ ):
640693 # NOTE: ContextBase inherits from ResultBase, so order matters here!
641694 new_result_type = arg
642695 if new_context_type and new_result_type :
@@ -666,11 +719,25 @@ def _determine_context_result(cls):
666719 if new_result_type is not None :
667720 # Validate that the model's result_type match
668721 annotation_result_type = _cached_signature (cls .__call__ ).return_annotation
669- if (
670- annotation_result_type is not Signature .empty
671- and not isinstance (annotation_result_type , TypeVar )
672- and not issubclass (annotation_result_type , new_result_type )
673- ):
722+ if annotation_result_type is Signature .empty :
723+ ...
724+ elif isinstance (annotation_result_type , TypeVar ):
725+ ...
726+ elif get_origin (annotation_result_type ) is Union and get_origin (new_result_type ) is Union :
727+ raise TypeError (
728+ f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received { annotation_result_type } "
729+ )
730+ elif get_origin (annotation_result_type ) is Union :
731+ if not any (issubclass (new_result_type , union_type ) for union_type in get_args (annotation_result_type )):
732+ raise TypeError (
733+ f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
734+ )
735+ elif get_origin (new_result_type ) is Union :
736+ if not any (issubclass (annotation_result_type , union_type ) for union_type in get_args (new_result_type )):
737+ raise TypeError (
738+ f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
739+ )
740+ elif not issubclass (annotation_result_type , new_result_type ):
674741 raise TypeError (
675742 f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
676743 )
0 commit comments