1515import logging
1616from functools import lru_cache , wraps
1717from inspect import Signature , isclass , signature
18- from typing import Any , ClassVar , Dict , Generic , List , Optional , Tuple , Type , TypeVar
18+ from typing import Any , ClassVar , Dict , Generic , List , Optional , Tuple , Type , TypeVar , Union , get_args , get_origin
1919
2020from pydantic import BaseModel as PydanticBaseModel , ConfigDict , Field , InstanceOf , PrivateAttr , TypeAdapter , field_validator , model_validator
2121from typing_extensions import override
@@ -217,7 +217,9 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di
217217 def wrapper (model , context = Signature .empty , * , _options : Optional [FlowOptions ] = None , ** kwargs ):
218218 if not isinstance (model , CallableModel ):
219219 raise TypeError (f"Can only decorate methods on CallableModels (not { type (model )} ) with the flow decorator." )
220- if not isclass (model .context_type ) or not issubclass (model .context_type , ContextBase ):
220+ if (not isclass (model .context_type ) or not issubclass (model .context_type , ContextBase )) and not (
221+ get_origin (model .context_type ) is Union and type (None ) in get_args (model .context_type )
222+ ):
221223 raise TypeError (f"Context type { model .context_type } must be a subclass of ContextBase" )
222224 if not isclass (model .result_type ) or not issubclass (model .result_type , ResultBase ):
223225 raise TypeError (f"Result type { model .result_type } must be a subclass of ResultBase" )
@@ -237,7 +239,11 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
237239
238240 # Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
239241 if not isinstance (context , model .context_type ):
240- context = model .context_type .model_validate (context )
242+ if get_origin (model .context_type ) is Union and type (None ) in get_args (model .context_type ):
243+ model_context_type = [t for t in get_args (model .context_type ) if t is not type (None )][0 ]
244+ else :
245+ model_context_type = model .context_type
246+ context = model_context_type .model_validate (context )
241247
242248 if fn != getattr (model .__class__ , fn .__name__ ).__wrapped__ :
243249 # This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
@@ -385,7 +391,8 @@ class ModelEvaluationContext(
385391 fn : str = Field ("__call__" , strict = True )
386392 options : Dict [str , Any ] = Field (default_factory = dict )
387393 model : InstanceOf [_CallableModel ]
388- context : InstanceOf [ContextBase ]
394+ context : Union [InstanceOf [ContextBase ], None ]
395+
389396 # Using InstanceOf instead of the actual type will limit Pydantic's validation of the field to instance checking
390397 # Otherwise, the validation will re-run fully despite the models already being validated on construction
391398 # TODO: Make the instance check compatible with the generic types instead of the base type
@@ -492,9 +499,15 @@ def context_type(self) -> Type[ContextType]:
492499 typ = _cached_signature (self .__class__ .__call__ ).parameters ["context" ].annotation
493500 if typ is Signature .empty :
494501 raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
495- if not issubclass (typ , ContextBase ):
496- raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { typ } ." )
497502
503+ # If optional type, extract inner type
504+ if get_origin (typ ) is Optional or (get_origin (typ ) is Union and type (None ) in get_args (typ )):
505+ typ_to_check = [t for t in get_args (typ ) if t is not type (None )][0 ]
506+ else :
507+ typ_to_check = typ
508+ # Ensure subclass of ContextBase
509+ if not issubclass (typ_to_check , ContextBase ):
510+ raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { typ_to_check } ." )
498511 return typ
499512
500513 @property
0 commit comments