From de87bb26551e8e4feba8c2ad3dbff99f701446e4 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 03:56:58 -0500 Subject: [PATCH 1/2] Add ability to define dynamic context from kwargs Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 165 ++++++++++++-- ccflow/tests/test_callable.py | 381 +++++++++++++++++++++++++++++++++ ccflow/tests/test_evaluator.py | 70 +++++- 3 files changed, 599 insertions(+), 17 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..b6580c9 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -12,10 +12,11 @@ """ import abc +import inspect import logging -from functools import lru_cache, wraps +from functools import lru_cache, partial, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -27,6 +28,7 @@ ResultBase, ResultType, ) +from .local_persistence import create_ccflow_model from .validators import str_to_log_level __all__ = ( @@ -44,6 +46,7 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + "dynamic_context", ) log = logging.getLogger(__name__) @@ -268,14 +271,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( - get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) - ): - raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not ( - get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type)) + + # Check if this is a dynamic_context decorated method + has_dynamic_context = hasattr(fn, "__dynamic_context__") + if has_dynamic_context: + method_context_type = fn.__dynamic_context__ + else: + method_context_type = model.context_type + + # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) + if not has_dynamic_context: + if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( + get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) + ): + raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") + + # Validate result type - use __result_type__ for dynamic contexts if available + if has_dynamic_context and hasattr(fn, "__result_type__"): + method_result_type = fn.__result_type__ + else: + method_result_type = model.result_type + if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not ( + get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type)) ): - raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase") + raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase") + if self._deps and fn.__name__ != "__deps__": raise ValueError("Can only apply Flow.deps decorator to __deps__") if context is Signature.empty: @@ -285,18 +305,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = context = kwargs else: raise TypeError( - f"{fn.__name__}() missing 1 required positional argument: 'context' of type {model.context_type}, or kwargs to construct it" + f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it" ) elif kwargs: # Kwargs passed in as well as context. Not allowed raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'") # Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message - if not isinstance(context, model.context_type): - if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type): - model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0] + if not isinstance(context, method_context_type): + if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type): + coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0] else: - model_context_type = model.context_type - context = model_context_type.model_validate(context) + coerce_context_type = method_context_type + context = coerce_context_type.model_validate(context) if fn != getattr(model.__class__, fn.__name__).__wrapped__: # This happens when super().__call__ is used when implementing a CallableModel that derives from another one. @@ -313,6 +333,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_evaluator = self.get_evaluator wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context + + # Preserve dynamic context attributes for introspection + if hasattr(fn, "__dynamic_context__"): + wrap.__dynamic_context__ = fn.__dynamic_context__ + if hasattr(fn, "__result_type__"): + wrap.__result_type__ = fn.__result_type__ + return wrap @@ -417,6 +444,49 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def dynamic_call(*args, **kwargs): + """Decorator for methods that creates a dynamic context from the function signature. + + This combines @Flow.call and @dynamic_context into a single decorator, allowing + you to define the context inline in the function signature instead of creating + a separate context class. + + Example: + class MyModel(CallableModel): + @Flow.dynamic_call + def __call__(self, *, a: int, b: str = "default") -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + model = MyModel() + model(a=42) # Works with kwargs + model(a=42, b="test") # Also works + + Args: + *args: When used without arguments, the decorated function + **kwargs: FlowOptions parameters (log_level, verbose, validate_result, etc.) + plus dynamic_context options: + - parent: Optional parent context class to inherit from + """ + # Import here to avoid circular import at module level + from ccflow.callable import dynamic_context + + # Extract dynamic_context-specific options + parent = kwargs.pop("parent", None) + + if len(args) == 1 and callable(args[0]): + # No arguments to decorator (@Flow.dynamic_call) + fn = args[0] + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(wrapped) + else: + # Arguments to decorator (@Flow.dynamic_call(...)) + def decorator(fn): + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(**kwargs)(wrapped) + + return decorator + # ***************************************************************************** # Define "Evaluators" and associated types @@ -754,3 +824,68 @@ def _validate_callable_model_generic_type(cls, m, handler, info): CallableModelGenericType = CallableModelGeneric + + +# ***************************************************************************** +# Dynamic Context Decorator +# ***************************************************************************** + + +def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: + """Decorator that creates a dynamic context class from function parameters. + + This decorator extracts the parameters from a function signature and creates + a dynamic ContextBase subclass whose fields correspond to those parameters. + The decorated function is then wrapped to accept the context object and + unpack it into keyword arguments. + + Example: + class MyCallable(CallableModel): + @Flow.dynamic_call # or @Flow.call @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = MyCallable() + model(x=42, y="hello") # Works with kwargs + """ + if func is None: + return partial(dynamic_context, parent=parent) + + sig = signature(func) + base_class = parent or ContextBase + + # Validate parent fields are in function signature + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters.keys()) - {"self"} + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + # Build fields from parameters (skip 'self'), pydantic validates types + fields = {} + for name, param in sig.parameters.items(): + if name == "self": + continue + default = ... if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Create dynamic context class + dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in fields} + return func(self, **fn_kwargs) + + # Must set __signature__ so CallableModel validation sees 'context' parameter + wrapper.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + ], + return_annotation=sig.return_annotation, + ) + wrapper.__dynamic_context__ = dyn_context + wrapper.__result_type__ = sig.return_annotation + return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..444d496 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -20,6 +20,7 @@ ResultBase, ResultType, WrapperModel, + dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -783,3 +784,383 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +# ============================================================================= +# Tests for dynamic_context decorator +# ============================================================================= + + +class TestDynamicContext(TestCase): + """Tests for the @dynamic_context decorator.""" + + def test_basic_usage_with_kwargs(self): + """Test basic dynamic_context usage with keyword arguments.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Call with kwargs + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + # Call with default + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attribute(self): + """Test that __dynamic_context__ attribute is set.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int, b: str) -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + # The __call__ method should have __dynamic_context__ + call_method = DynamicCallable.__call__ + self.assertTrue(hasattr(call_method, "__wrapped__")) + # Access the inner function's __dynamic_context__ + inner = call_method.__wrapped__ + self.assertTrue(hasattr(inner, "__dynamic_context__")) + + dyn_ctx = inner.__dynamic_context__ + self.assertTrue(issubclass(dyn_ctx, ContextBase)) + self.assertIn("a", dyn_ctx.model_fields) + self.assertIn("b", dyn_ctx.model_fields) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context is registered for serialization.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = DynamicCallable.__call__.__wrapped__ + dyn_ctx = inner.__dynamic_context__ + + # Should have __ccflow_import_path__ set + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + """Test calling with a context object instead of kwargs.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Get the dynamic context class + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Create a context object + ctx = dyn_ctx(x=99, y="context") + result = model(ctx) + self.assertEqual(result.value, "99-context") + + def test_with_parent_context(self): + """Test dynamic_context with parent context class.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + # Get dynamic context + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Should have both fields + self.assertIn("base_value", dyn_ctx.model_fields) + self.assertIn("x", dyn_ctx.model_fields) + + # Create context with parent field + ctx = dyn_ctx(x=42, base_value="custom") + self.assertEqual(ctx.base_value, "custom") + self.assertEqual(ctx.x, 42) + + def test_parent_fields_must_be_in_signature(self): + """Test that parent fields must be included in function signature.""" + + class ParentContext(ContextBase): + required_field: str + + with self.assertRaises(TypeError) as cm: + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + self.assertIn("required_field", str(cm.exception)) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip for dynamic context callable.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + + # Test roundtrip + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task_execution(self): + """Test dynamic context callable in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) # (10 + 2) * 5 + + def test_multiple_dynamic_context_methods(self): + """Test callable with multiple dynamic_context decorated methods.""" + + class MultiMethodCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int) -> GenericResult: + return GenericResult(value=a) + + @dynamic_context + def other_method(self, *, b: str, c: float = 1.0) -> GenericResult: + return GenericResult(value=f"{b}-{c}") + + model = MultiMethodCallable() + + # Test __call__ + result1 = model(a=42) + self.assertEqual(result1.value, 42) + + # Test other_method (without Flow.call, just the dynamic_context wrapper) + # Need to create the context manually + other_ctx = model.other_method.__dynamic_context__ + ctx = other_ctx(b="hello", c=2.5) + result2 = model.other_method(ctx) + self.assertEqual(result2.value, "hello-2.5") + + def test_context_type_property_works(self): + """Test that type_ property works on the dynamic context.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + ctx = dyn_ctx(x=42) + + # type_ should work and be importable + type_path = str(ctx.type_) + self.assertIn("_Local_", type_path) + self.assertEqual(ctx.type_.object, dyn_ctx) + + def test_complex_field_types(self): + """Test dynamic_context with complex field types.""" + from typing import List, Optional + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__( + self, + *, + items: List[int], + name: Optional[str] = None, + count: int = 0, + ) -> GenericResult: + total = sum(items) + count + return GenericResult(value=f"{name}:{total}" if name else str(total)) + + model = DynamicCallable() + + result = model(items=[1, 2, 3], name="test", count=10) + self.assertEqual(result.value, "test:16") + + result = model(items=[5, 5]) + self.assertEqual(result.value, "10") + + +class TestFlowDynamicCall(TestCase): + """Tests for @Flow.dynamic_call decorator.""" + + def test_basic_usage(self): + """Test basic @Flow.dynamic_call usage.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attributes_preserved(self): + """Test that __dynamic_context__ and __result_type__ are directly accessible.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + # Should be directly accessible without traversing __wrapped__ chain + method = DynamicCallable.__call__ + self.assertTrue(hasattr(method, "__dynamic_context__")) + self.assertTrue(hasattr(method, "__result_type__")) + self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) + self.assertEqual(method.__result_type__, GenericResult) + + def test_model_result_type_property(self): + """Test that model.result_type returns correct type for dynamic contexts.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + self.assertEqual(model.result_type, GenericResult) + + def test_with_parent_context(self): + """Test @Flow.dynamic_call with parent context.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + model = DynamicCallable() + + # Get dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Call should work, uses parent default + result = model(x=42, base_value="custom") + self.assertEqual(result.value, "42-custom") + + def test_with_flow_options(self): + """Test @Flow.dynamic_call with FlowOptions parameters.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(validate_result=False) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + result = model(x=42) + self.assertEqual(result.value, 42) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip with @Flow.dynamic_call.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task(self): + """Test @Flow.dynamic_call in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context from @Flow.dynamic_call is registered.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + # Find dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + +def _find_dynamic_context(func): + """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" + visited = set() + current = func + while current is not None and id(current) not in visited: + visited.add(id(current)) + if hasattr(current, "__dynamic_context__"): + return current.__dynamic_context__ + current = getattr(current, "__wrapped__", None) + return None diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..34f3f7e 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyDynamicDateCallable(CallableModel): + """Dynamic context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.dynamic_call(parent=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyDynamicDateCallable], + ids=["standard", "dynamic"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and dynamic context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 From 95119e335d80a3edb4c2e255eef21adb4331be7a Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 23:19:39 -0500 Subject: [PATCH 2/2] Remove dynamic context, add option to Flow.call Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 186 +++++++-------- ccflow/tests/test_callable.py | 318 +++++++------------------ ccflow/tests/test_evaluator.py | 12 +- ccflow/tests/test_local_persistence.py | 2 +- 4 files changed, 180 insertions(+), 338 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 9b971c7..748759c 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,7 +14,7 @@ import abc import inspect import logging -from functools import lru_cache, partial, wraps +from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -46,7 +46,6 @@ "EvaluatorBase", "Evaluator", "WrapperModel", - "dynamic_context", ) log = logging.getLogger(__name__) @@ -272,22 +271,22 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - # Check if this is a dynamic_context decorated method - has_dynamic_context = hasattr(fn, "__dynamic_context__") - if has_dynamic_context: - method_context_type = fn.__dynamic_context__ + # Check if this is an auto_context decorated method + has_auto_context = hasattr(fn, "__auto_context__") + if has_auto_context: + method_context_type = fn.__auto_context__ else: method_context_type = model.context_type - # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) - if not has_dynamic_context: + # Validate context type (skip for auto contexts which are always valid ContextBase subclasses) + if not has_auto_context: if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) ): raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - # Validate result type - use __result_type__ for dynamic contexts if available - if has_dynamic_context and hasattr(fn, "__result_type__"): + # Validate result type - use __result_type__ for auto contexts if available + if has_auto_context and hasattr(fn, "__result_type__"): method_result_type = fn.__result_type__ else: method_result_type = model.result_type @@ -334,9 +333,9 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context - # Preserve dynamic context attributes for introspection - if hasattr(fn, "__dynamic_context__"): - wrap.__dynamic_context__ = fn.__dynamic_context__ + # Preserve auto context attributes for introspection + if hasattr(fn, "__auto_context__"): + wrap.__auto_context__ = fn.__auto_context__ if hasattr(fn, "__result_type__"): wrap.__result_type__ = fn.__result_type__ @@ -418,7 +417,58 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - """Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Args: + auto_context: Controls automatic context class generation from the function + signature. Accepts three types of values: + - False (default): No auto-generation, use traditional context parameter + - True: Auto-generate context class with no parent + - ContextBase subclass: Auto-generate context class inheriting from this parent + **kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result, + cacheable, evaluator, volatile). + + Basic Example: + class MyModel(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> MyResult: + return MyResult(value=context.x) + + Auto Context Example: + class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> MyResult: + return MyResult(value=f"{x}-{y}") + + model = MyModel() + model(x=42) # Call with kwargs directly + + With Parent Context: + class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> MyResult: + return MyResult(value=date.day + extra) + + # The generated context inherits from DateContext, so it's compatible + # with infrastructure expecting DateContext instances. + """ + # Extract auto_context option (not part of FlowOptions) + # Can be: False, True, or a ContextBase subclass + auto_context = kwargs.pop("auto_context", False) + + # Determine if auto_context is enabled and extract parent class if provided + if auto_context is False: + auto_context_enabled = False + context_parent = None + elif auto_context is True: + auto_context_enabled = True + context_parent = None + elif isclass(auto_context) and issubclass(auto_context, ContextBase): + auto_context_enabled = True + context_parent = auto_context + else: + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -427,6 +477,14 @@ def call(*args, **kwargs): else: # Arguments to decorator, this is just returning the decorator # Note that the code below is executed only once + if auto_context_enabled: + # Return a decorator that first applies auto_context, then FlowOptions + def auto_context_decorator(fn): + wrapped = _apply_auto_context(fn, parent=context_parent) + # FlowOptions.__call__ already applies wraps, so we just return its result + return FlowOptions(**kwargs)(wrapped) + + return auto_context_decorator return FlowOptions(**kwargs) @staticmethod @@ -444,81 +502,6 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) - @staticmethod - def dynamic_call(*args, **kwargs): - """Decorator that combines @Flow.call with dynamic context creation. - - Instead of defining a separate context class, this decorator creates one - automatically from the function signature. The method can then be called - with keyword arguments directly. - - Basic Example: - class MyModel(CallableModel): - @Flow.dynamic_call - def __call__(self, *, date: date, region: str = "US") -> MyResult: - return MyResult(value=f"{date}-{region}") - - model = MyModel() - model(date=date.today()) # Uses default region="US" - model(date=date.today(), region="EU") # Override default - - With Parent Context: - class MyModel(CallableModel): - @Flow.dynamic_call(parent=DateContext) - def __call__(self, *, date: date, extra: int = 0) -> MyResult: - return MyResult(value=date.day + extra) - - # Parent fields (date) must be included in the function signature. - # This is useful for integrating with existing infrastructure that - # expects specific context types. - - Args: - *args: The decorated function when used without parentheses - **kwargs: Combined options for FlowOptions and dynamic_context: - - Dynamic context options: - parent: Parent context class to inherit from. All parent fields - must appear in the function signature. - - FlowOptions (passed through to @Flow.call): - log_level: Logging level for evaluation (default: DEBUG) - verbose: Use verbose logging (default: True) - validate_result: Validate return against result_type (default: True) - cacheable: Allow result caching (default: False) - evaluator: Custom evaluator instance - - Returns: - A decorated method that accepts keyword arguments matching the signature. - - Notes: - - All parameters (except 'self') must have type annotations - - Use keyword-only parameters (after *) for cleaner signatures - - The generated context class is accessible via method.__dynamic_context__ - - The return type is accessible via method.__result_type__ - - See Also: - dynamic_context: The underlying decorator for context creation - Flow.call: The underlying decorator for flow evaluation - """ - # Import here to avoid circular import at module level - from ccflow.callable import dynamic_context - - # Extract dynamic_context-specific options - parent = kwargs.pop("parent", None) - - if len(args) == 1 and callable(args[0]): - # No arguments to decorator (@Flow.dynamic_call) - fn = args[0] - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(wrapped) - else: - # Arguments to decorator (@Flow.dynamic_call(...)) - def decorator(fn): - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(**kwargs)(wrapped) - - return decorator - # ***************************************************************************** # Define "Evaluators" and associated types @@ -859,30 +842,29 @@ def _validate_callable_model_generic_type(cls, m, handler, info): # ***************************************************************************** -# Dynamic Context Decorator +# Auto Context (internal helper for Flow.call(auto_context=True)) # ***************************************************************************** -def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: - """Decorator that creates a dynamic context class from function parameters. +def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable: + """Internal function that creates an auto context class from function parameters. - This decorator extracts the parameters from a function signature and creates - a dynamic ContextBase subclass whose fields correspond to those parameters. + This function extracts the parameters from a function signature and creates + a ContextBase subclass whose fields correspond to those parameters. The decorated function is then wrapped to accept the context object and unpack it into keyword arguments. + Used internally by Flow.call(auto_context=...). + Example: class MyCallable(CallableModel): - @Flow.dynamic_call # or @Flow.call @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") model = MyCallable() model(x=42, y="hello") # Works with kwargs """ - if func is None: - return partial(dynamic_context, parent=parent) - sig = signature(func) base_class = parent or ContextBase @@ -902,8 +884,8 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: default = ... if param.default is inspect.Parameter.empty else param.default fields[name] = (param.annotation, default) - # Create dynamic context class - dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + # Create auto context class + auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields) @wraps(func) def wrapper(self, context): @@ -914,10 +896,10 @@ def wrapper(self, context): wrapper.__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), ], return_annotation=sig.return_annotation, ) - wrapper.__dynamic_context__ = dyn_context + wrapper.__auto_context__ = auto_context_class wrapper.__result_type__ = sig.return_annotation return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 444d496..a748765 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -20,7 +20,6 @@ ResultBase, ResultType, WrapperModel, - dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -787,23 +786,22 @@ def foo(self, context): # ============================================================================= -# Tests for dynamic_context decorator +# Tests for Flow.call(auto_context=True) # ============================================================================= -class TestDynamicContext(TestCase): - """Tests for the @dynamic_context decorator.""" +class TestAutoContext(TestCase): + """Tests for @Flow.call(auto_context=True).""" def test_basic_usage_with_kwargs(self): - """Test basic dynamic_context usage with keyword arguments.""" + """Test basic auto_context usage with keyword arguments.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = AutoContextCallable() # Call with kwargs result = model(x=42, y="hello") @@ -813,117 +811,111 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: result = model(x=10) self.assertEqual(result.value, "10-default") - def test_dynamic_context_attribute(self): - """Test that __dynamic_context__ attribute is set.""" + def test_auto_context_attribute(self): + """Test that __auto_context__ attribute is set.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, a: int, b: str) -> GenericResult: return GenericResult(value=f"{a}-{b}") - # The __call__ method should have __dynamic_context__ - call_method = DynamicCallable.__call__ + # The __call__ method should have __auto_context__ + call_method = AutoContextCallable.__call__ self.assertTrue(hasattr(call_method, "__wrapped__")) - # Access the inner function's __dynamic_context__ + # Access the inner function's __auto_context__ inner = call_method.__wrapped__ - self.assertTrue(hasattr(inner, "__dynamic_context__")) + self.assertTrue(hasattr(inner, "__auto_context__")) - dyn_ctx = inner.__dynamic_context__ - self.assertTrue(issubclass(dyn_ctx, ContextBase)) - self.assertIn("a", dyn_ctx.model_fields) - self.assertIn("b", dyn_ctx.model_fields) + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("a", auto_ctx.model_fields) + self.assertIn("b", auto_ctx.model_fields) - def test_dynamic_context_is_registered(self): - """Test that the dynamic context is registered for serialization.""" + def test_auto_context_is_registered(self): + """Test that the auto context is registered for serialization.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, value: int) -> GenericResult: return GenericResult(value=value) - inner = DynamicCallable.__call__.__wrapped__ - dyn_ctx = inner.__dynamic_context__ + inner = AutoContextCallable.__call__.__wrapped__ + auto_ctx = inner.__auto_context__ # Should have __ccflow_import_path__ set - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) def test_call_with_context_object(self): """Test calling with a context object instead of kwargs.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = AutoContextCallable() - # Get the dynamic context class - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get the auto context class + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Create a context object - ctx = dyn_ctx(x=99, y="context") + ctx = auto_ctx(x=99, y="context") result = model(ctx) self.assertEqual(result.value, "99-context") def test_with_parent_context(self): - """Test dynamic_context with parent context class.""" + """Test auto_context with a parent context class.""" class ParentContext(ContextBase): base_value: str = "base" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int, base_value: str) -> GenericResult: return GenericResult(value=f"{x}-{base_value}") - # Get dynamic context - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get auto context + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) + self.assertTrue(issubclass(auto_ctx, ParentContext)) # Should have both fields - self.assertIn("base_value", dyn_ctx.model_fields) - self.assertIn("x", dyn_ctx.model_fields) + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) # Create context with parent field - ctx = dyn_ctx(x=42, base_value="custom") + ctx = auto_ctx(x=42, base_value="custom") self.assertEqual(ctx.base_value, "custom") self.assertEqual(ctx.x, 42) def test_parent_fields_must_be_in_signature(self): - """Test that parent fields must be included in function signature.""" + """Test that parent context fields must be included in function signature.""" class ParentContext(ContextBase): required_field: str with self.assertRaises(TypeError) as cm: - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) self.assertIn("required_field", str(cm.exception)) def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip for dynamic context callable.""" + """Test cloudpickle roundtrip for auto_context callable.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): multiplier: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x * self.multiplier) - model = DynamicCallable(multiplier=3) + model = AutoContextCallable(multiplier=3) # Test roundtrip restored = rcploads(rcpdumps(model)) @@ -932,13 +924,12 @@ def __call__(self, *, x: int) -> GenericResult: self.assertEqual(result.value, 30) def test_ray_task_execution(self): - """Test dynamic context callable in Ray task.""" + """Test auto_context callable in Ray task.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): factor: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: int = 1) -> GenericResult: return GenericResult(value=(x + y) * self.factor) @@ -946,63 +937,35 @@ def __call__(self, *, x: int, y: int = 1) -> GenericResult: def run_callable(model, **kwargs): return model(**kwargs).value - model = DynamicCallable(factor=5) + model = AutoContextCallable(factor=5) with ray.init(num_cpus=1): result = ray.get(run_callable.remote(model, x=10, y=2)) self.assertEqual(result, 60) # (10 + 2) * 5 - def test_multiple_dynamic_context_methods(self): - """Test callable with multiple dynamic_context decorated methods.""" - - class MultiMethodCallable(CallableModel): - @Flow.call - @dynamic_context - def __call__(self, *, a: int) -> GenericResult: - return GenericResult(value=a) - - @dynamic_context - def other_method(self, *, b: str, c: float = 1.0) -> GenericResult: - return GenericResult(value=f"{b}-{c}") - - model = MultiMethodCallable() - - # Test __call__ - result1 = model(a=42) - self.assertEqual(result1.value, 42) - - # Test other_method (without Flow.call, just the dynamic_context wrapper) - # Need to create the context manually - other_ctx = model.other_method.__dynamic_context__ - ctx = other_ctx(b="hello", c=2.5) - result2 = model.other_method(ctx) - self.assertEqual(result2.value, "hello-2.5") - def test_context_type_property_works(self): - """Test that type_ property works on the dynamic context.""" + """Test that type_ property works on the auto context.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ - ctx = dyn_ctx(x=42) + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=42) # type_ should work and be importable type_path = str(ctx.type_) self.assertIn("_Local_", type_path) - self.assertEqual(ctx.type_.object, dyn_ctx) + self.assertEqual(ctx.type_.object, auto_ctx) def test_complex_field_types(self): - """Test dynamic_context with complex field types.""" + """Test auto_context with complex field types.""" from typing import List, Optional - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__( self, *, @@ -1013,7 +976,7 @@ def __call__( total = sum(items) + count return GenericResult(value=f"{name}:{total}" if name else str(total)) - model = DynamicCallable() + model = AutoContextCallable() result = model(items=[1, 2, 3], name="test", count=10) self.assertEqual(result.value, "test:16") @@ -1021,146 +984,43 @@ def __call__( result = model(items=[5, 5]) self.assertEqual(result.value, "10") - -class TestFlowDynamicCall(TestCase): - """Tests for @Flow.dynamic_call decorator.""" - - def test_basic_usage(self): - """Test basic @Flow.dynamic_call usage.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int, y: str = "default") -> GenericResult: - return GenericResult(value=f"{x}-{y}") - - model = DynamicCallable() - - result = model(x=42, y="hello") - self.assertEqual(result.value, "42-hello") - - result = model(x=10) - self.assertEqual(result.value, "10-default") - - def test_dynamic_context_attributes_preserved(self): - """Test that __dynamic_context__ and __result_type__ are directly accessible.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - # Should be directly accessible without traversing __wrapped__ chain - method = DynamicCallable.__call__ - self.assertTrue(hasattr(method, "__dynamic_context__")) - self.assertTrue(hasattr(method, "__result_type__")) - self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) - self.assertEqual(method.__result_type__, GenericResult) - - def test_model_result_type_property(self): - """Test that model.result_type returns correct type for dynamic contexts.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - model = DynamicCallable() - self.assertEqual(model.result_type, GenericResult) - - def test_with_parent_context(self): - """Test @Flow.dynamic_call with parent context.""" - - class ParentContext(ContextBase): - base_value: str = "base" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call(parent=ParentContext) - def __call__(self, *, x: int, base_value: str) -> GenericResult: - return GenericResult(value=f"{x}-{base_value}") - - model = DynamicCallable() - - # Get dynamic context by traversing __wrapped__ chain - dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) - - # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) - - # Call should work, uses parent default - result = model(x=42, base_value="custom") - self.assertEqual(result.value, "42-custom") - def test_with_flow_options(self): - """Test @Flow.dynamic_call with FlowOptions parameters.""" + """Test auto_context with FlowOptions parameters.""" - class DynamicCallable(CallableModel): - @Flow.dynamic_call(validate_result=False) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True, validate_result=False) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - model = DynamicCallable() + model = AutoContextCallable() result = model(x=42) self.assertEqual(result.value, 42) - def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip with @Flow.dynamic_call.""" - - class DynamicCallable(CallableModel): - multiplier: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x * self.multiplier) - - model = DynamicCallable(multiplier=3) - restored = rcploads(rcpdumps(model)) - - result = restored(x=10) - self.assertEqual(result.value, 30) - - def test_ray_task(self): - """Test @Flow.dynamic_call in Ray task.""" - - class DynamicCallable(CallableModel): - factor: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int, y: int = 1) -> GenericResult: - return GenericResult(value=(x + y) * self.factor) - - @ray.remote - def run_callable(model, **kwargs): - return model(**kwargs).value - - model = DynamicCallable(factor=5) + def test_error_without_auto_context(self): + """Test that using kwargs signature without auto_context raises an error.""" - with ray.init(num_cpus=1): - result = ray.get(run_callable.remote(model, x=10, y=2)) - - self.assertEqual(result, 60) - - def test_dynamic_context_is_registered(self): - """Test that the dynamic context from @Flow.dynamic_call is registered.""" + class BadCallable(CallableModel): + @Flow.call # Missing auto_context=True! + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, value: int) -> GenericResult: - return GenericResult(value=value) + # Error happens at instantiation time when _check_signature validates + with self.assertRaises(ValueError) as cm: + BadCallable() - # Find dynamic context by traversing __wrapped__ chain - dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + # Should fail because __call__ must take a single argument named 'context' + error_msg = str(cm.exception) + self.assertIn("__call__", error_msg) + self.assertIn("context", error_msg) - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + def test_invalid_auto_context_value(self): + """Test that invalid auto_context values raise TypeError with helpful message.""" + with self.assertRaises(TypeError) as cm: + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) -def _find_dynamic_context(func): - """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" - visited = set() - current = func - while current is not None and id(current) not in visited: - visited.add(id(current)) - if hasattr(current, "__dynamic_context__"): - return current.__dynamic_context__ - current = getattr(current, "__wrapped__", None) - return None + error_msg = str(cm.exception) + self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) + self.assertIn("invalid", error_msg) diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index 34f3f7e..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -8,12 +8,12 @@ from .evaluators.util import MyDateCallable, MyResult -class MyDynamicDateCallable(CallableModel): - """Dynamic context version of MyDateCallable for testing evaluators.""" +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" offset: int - @Flow.dynamic_call(parent=DateContext) + @Flow.call(auto_context=DateContext) def __call__(self, *, date: date) -> MyResult: return MyResult(x=date.day + self.offset) @@ -48,11 +48,11 @@ def test_evaluator_deps(self): @pytest.mark.parametrize( "callable_class", - [MyDateCallable, MyDynamicDateCallable], - ids=["standard", "dynamic"], + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], ) class TestEvaluatorParametrized: - """Test evaluators work with both standard and dynamic context callables.""" + """Test evaluators work with both standard and auto_context callables.""" def test_evaluator_with_context_object(self, callable_class): """Test evaluator with a context object.""" diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 586b03f..dc2db55 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1235,7 +1235,7 @@ class TestCreateCcflowModelCloudpickleCrossProcess: id="context_only", ), pytest.param( - # Dynamic context with CallableModel + # Runtime-created context with CallableModel """ from ray.cloudpickle import dump from ccflow import CallableModel, ContextBase, GenericResult, Flow