Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 164 additions & 15 deletions ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
"""

import abc
import inspect
import logging
from functools import lru_cache, wraps
from inspect import Signature, isclass, signature
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
from typing_extensions import override
Expand All @@ -27,6 +28,7 @@
ResultBase,
ResultType,
)
from .local_persistence import create_ccflow_model
from .validators import str_to_log_level

__all__ = (
Expand Down Expand Up @@ -268,14 +270,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di
def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs):
if not isinstance(model, CallableModel):
raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.")
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
):
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")
if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not (
get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type))

# Check if this is an auto_context decorated method
has_auto_context = hasattr(fn, "__auto_context__")
if has_auto_context:
method_context_type = fn.__auto_context__
else:
method_context_type = model.context_type

# Validate context type (skip for auto contexts which are always valid ContextBase subclasses)
if not has_auto_context:
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
):
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")

# Validate result type - use __result_type__ for auto contexts if available
if has_auto_context and hasattr(fn, "__result_type__"):
method_result_type = fn.__result_type__
else:
method_result_type = model.result_type
if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not (
get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type))
):
raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase")
raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase")

if self._deps and fn.__name__ != "__deps__":
raise ValueError("Can only apply Flow.deps decorator to __deps__")
if context is Signature.empty:
Expand All @@ -285,18 +304,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
context = kwargs
else:
raise TypeError(
f"{fn.__name__}() missing 1 required positional argument: 'context' of type {model.context_type}, or kwargs to construct it"
f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it"
)
elif kwargs: # Kwargs passed in as well as context. Not allowed
raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'")

# Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
if not isinstance(context, model.context_type):
if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type):
model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0]
if not isinstance(context, method_context_type):
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
else:
model_context_type = model.context_type
context = model_context_type.model_validate(context)
coerce_context_type = method_context_type
context = coerce_context_type.model_validate(context)

if fn != getattr(model.__class__, fn.__name__).__wrapped__:
# This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
Expand All @@ -313,6 +332,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
wrap.get_evaluator = self.get_evaluator
wrap.get_options = self.get_options
wrap.get_evaluation_context = get_evaluation_context

# Preserve auto context attributes for introspection
if hasattr(fn, "__auto_context__"):
wrap.__auto_context__ = fn.__auto_context__
if hasattr(fn, "__result_type__"):
wrap.__result_type__ = fn.__result_type__

return wrap


Expand Down Expand Up @@ -391,7 +417,58 @@ def __exit__(self, exc_type, exc_value, exc_tb):
class Flow(PydanticBaseModel):
@staticmethod
def call(*args, **kwargs):
"""Decorator for methods on callable models"""
"""Decorator for methods on callable models.

Args:
auto_context: Controls automatic context class generation from the function
signature. Accepts three types of values:
- False (default): No auto-generation, use traditional context parameter
- True: Auto-generate context class with no parent
- ContextBase subclass: Auto-generate context class inheriting from this parent
**kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result,
cacheable, evaluator, volatile).

Basic Example:
class MyModel(CallableModel):
@Flow.call
def __call__(self, context: MyContext) -> MyResult:
return MyResult(value=context.x)

Auto Context Example:
class MyModel(CallableModel):
@Flow.call(auto_context=True)
def __call__(self, *, x: int, y: str = "default") -> MyResult:
return MyResult(value=f"{x}-{y}")

model = MyModel()
model(x=42) # Call with kwargs directly

With Parent Context:
class MyModel(CallableModel):
@Flow.call(auto_context=DateContext)
def __call__(self, *, date: date, extra: int = 0) -> MyResult:
return MyResult(value=date.day + extra)

# The generated context inherits from DateContext, so it's compatible
# with infrastructure expecting DateContext instances.
"""
# Extract auto_context option (not part of FlowOptions)
# Can be: False, True, or a ContextBase subclass
auto_context = kwargs.pop("auto_context", False)

# Determine if auto_context is enabled and extract parent class if provided
if auto_context is False:
auto_context_enabled = False
context_parent = None
elif auto_context is True:
auto_context_enabled = True
context_parent = None
elif isclass(auto_context) and issubclass(auto_context, ContextBase):
auto_context_enabled = True
context_parent = auto_context
else:
raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}")

if len(args) == 1 and callable(args[0]):
# No arguments to decorator, this is the decorator
fn = args[0]
Expand All @@ -400,6 +477,14 @@ def call(*args, **kwargs):
else:
# Arguments to decorator, this is just returning the decorator
# Note that the code below is executed only once
if auto_context_enabled:
# Return a decorator that first applies auto_context, then FlowOptions
def auto_context_decorator(fn):
wrapped = _apply_auto_context(fn, parent=context_parent)
# FlowOptions.__call__ already applies wraps, so we just return its result
return FlowOptions(**kwargs)(wrapped)

return auto_context_decorator
return FlowOptions(**kwargs)

@staticmethod
Expand Down Expand Up @@ -754,3 +839,67 @@ def _validate_callable_model_generic_type(cls, m, handler, info):


CallableModelGenericType = CallableModelGeneric


# *****************************************************************************
# Auto Context (internal helper for Flow.call(auto_context=True))
# *****************************************************************************


def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable:
"""Internal function that creates an auto context class from function parameters.

This function extracts the parameters from a function signature and creates
a ContextBase subclass whose fields correspond to those parameters.
The decorated function is then wrapped to accept the context object and
unpack it into keyword arguments.

Used internally by Flow.call(auto_context=...).

Example:
class MyCallable(CallableModel):
@Flow.call(auto_context=True)
def __call__(self, *, x: int, y: str = "default") -> GenericResult:
return GenericResult(value=f"{x}-{y}")

model = MyCallable()
model(x=42, y="hello") # Works with kwargs
"""
sig = signature(func)
base_class = parent or ContextBase

# Validate parent fields are in function signature
if parent is not None:
parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys())
sig_params = set(sig.parameters.keys()) - {"self"}
missing = parent_fields - sig_params
if missing:
raise TypeError(f"Parent context fields {missing} must be included in function signature")

# Build fields from parameters (skip 'self'), pydantic validates types
fields = {}
for name, param in sig.parameters.items():
if name == "self":
continue
default = ... if param.default is inspect.Parameter.empty else param.default
fields[name] = (param.annotation, default)

# Create auto context class
auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields)

@wraps(func)
def wrapper(self, context):
fn_kwargs = {name: getattr(context, name) for name in fields}
return func(self, **fn_kwargs)

# Must set __signature__ so CallableModel validation sees 'context' parameter
wrapper.__signature__ = inspect.Signature(
parameters=[
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class),
],
return_annotation=sig.return_annotation,
)
wrapper.__auto_context__ = auto_context_class
wrapper.__result_type__ = sig.return_annotation
return wrapper
Loading