diff --git a/src/flyte/_context.py b/src/flyte/_context.py
index 60211a2bd..79fb71438 100644
--- a/src/flyte/_context.py
+++ b/src/flyte/_context.py
@@ -32,6 +32,7 @@ class ContextData:
     metadata: Optional[Tuple[Tuple[str, str], ...]] = None
     preserve_original_types: bool = False
     tracker: Any = None  # ActionTracker instance (optional, set for TUI runs)
+    in_trace: bool = False  # True when executing inside a @trace decorated function
 
     def replace(self, **kwargs) -> ContextData:
         return replace(self, **kwargs)
@@ -124,6 +125,13 @@ def is_task_context(self) -> bool:
         """
         return self.data.task_context is not None
 
+    def is_in_trace(self) -> bool:
+        """
+        Returns true if the context is in a trace context, else False
+        Returns: bool
+        """
+        return self.data.in_trace
+
     def __enter__(self):
         """Enter the context, setting it as the current context."""
         self._token = root_context_var.set(self)
diff --git a/src/flyte/_task.py b/src/flyte/_task.py
index 089bfa568..d141d1c06 100644
--- a/src/flyte/_task.py
+++ b/src/flyte/_task.py
@@ -24,7 +24,7 @@
 )
 
 from flyte._pod import PodTemplate
-from flyte.errors import RuntimeSystemError, RuntimeUserError
+from flyte.errors import RuntimeSystemError, RuntimeUserError, TraceDoesNotAllowNestedTasksError
 
 from ._cache import Cache, CacheRequest
 from ._context import internal_ctx
@@ -264,6 +264,12 @@ async def my_new_parent_task(n: int) -> List[int]:
         """
         ctx = internal_ctx()
         if ctx.is_task_context():
+            if ctx.is_in_trace():
+                raise TraceDoesNotAllowNestedTasksError(
+                    f"Task {self.name} is invoked from inside a `flyte.trace`. "
+                    "You can continue using the task function, as a regular "
+                    "python function using `task`.forward(...) facade."
+                )
             from ._internal.controllers import get_controller
 
             # If we are in a task context, that implies we are executing a Run.
@@ -307,6 +313,12 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R]
         try:
             ctx = internal_ctx()
             if ctx.is_task_context():
+                if ctx.is_in_trace():
+                    raise TraceDoesNotAllowNestedTasksError(
+                        f"Task {self.name} is invoked from inside a `flyte.trace`. "
+                        "You can continue using the task function, as a regular "
+                        "python function using `task`.forward(...) facade."
+                    )
                 # If we are in a task context, that implies we are executing a Run.
                 # In this scenario, we should submit the task to the controller.
                 # We will also check if we are not initialized, It is not expected to be not initialized
diff --git a/src/flyte/_trace.py b/src/flyte/_trace.py
index 8ac63cfcd..cf34f00d2 100644
--- a/src/flyte/_trace.py
+++ b/src/flyte/_trace.py
@@ -31,7 +31,7 @@ def wrapper_sync(*args: Any, **kwargs: Any) -> Any:
 
         @functools.wraps(func)
         async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
-            from flyte._context import internal_ctx
+            from flyte._context import Context, internal_ctx
 
             ctx = internal_ctx()
             if ctx.is_task_context():
@@ -53,10 +53,12 @@ async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
                 logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
                 start_time = time.time()
 
-                # Create a new context with the trace's action ID
+                # Create a new context with the trace's action ID and mark as in_trace
+                # so that nested task calls fail fast with TraceDoesNotAllowNestedTasksError.
                 # Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
                 trace_task_context = ctx.data.task_context.replace(action=info.action)  # type: ignore[union-attr]
-                trace_context = ctx.replace_task_context(trace_task_context)
+                trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
+                trace_context = Context(trace_data)
 
                 # Execute function in trace context, then record outside it
                 error = None
@@ -90,7 +92,7 @@ def is_async_iterable(obj: Any) -> TypeGuard[Union[AsyncGenerator, AsyncIterator
 
         @functools.wraps(func)
         async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
-            from flyte._context import internal_ctx
+            from flyte._context import Context, internal_ctx
 
             ctx = internal_ctx()
             if ctx.is_task_context():
@@ -110,10 +112,12 @@ async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any
                         raise info.error
                 start_time = time.time()
 
-                # Create a new context with the trace's action ID
+                # Create a new context with the trace's action ID and mark as in_trace
+                # so that nested task calls fail fast with TraceDoesNotAllowNestedTasksError.
                 # Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
                 trace_task_context = ctx.data.task_context.replace(action=info.action)  # type: ignore[union-attr]
-                trace_context = ctx.replace_task_context(trace_task_context)
+                trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
+                trace_context = Context(trace_data)
 
                 # Execute function in trace context, then record outside it
                 error = None
diff --git a/src/flyte/errors.py b/src/flyte/errors.py
index 4f81b8bd9..28bf5fcb0 100644
--- a/src/flyte/errors.py
+++ b/src/flyte/errors.py
@@ -283,3 +283,13 @@ class CodeBundleError(RuntimeUserError):
 
     def __init__(self, message: str):
         super().__init__("CodeBundleError", message, "user")
+
+
+class TraceDoesNotAllowNestedTasksError(RuntimeUserError):
+    """
+    This error is raised when the user tries to use a task from within a trace. Tasks can be nested under tasks
+    not traces.
+    """
+
+    def __init__(self, message: str):
+        super().__init__("TraceDoesNotAllowNestedTasksError", message)
diff --git a/tests/internal/test_context.py b/tests/internal/test_context.py
index d7d03ced5..a6a56cd8a 100644
--- a/tests/internal/test_context.py
+++ b/tests/internal/test_context.py
@@ -3,7 +3,8 @@
 import pytest
 
 import flyte
-from flyte._context import internal_ctx
+from flyte._context import Context, ContextData, internal_ctx
+from flyte.errors import TraceDoesNotAllowNestedTasksError
 from flyte.models import ActionID, RawDataPath, TaskContext
 from flyte.report import Report
 from flyte.syncify import syncify
@@ -217,3 +218,72 @@ def test_has_raw_data_priority():
         assert internal_ctx().has_raw_data is True
         # Also verify that raw_data returns the task context path
         assert internal_ctx().raw_data.path == "/task/path"
+
+
+# --- Tasks cannot be nested inside flyte.trace ---
+
+env = flyte.TaskEnvironment(name="test_trace_nesting")
+
+
+@env.task
+async def async_child_task(x: int) -> int:
+    return x + 1
+
+
+@env.task
+def sync_child_task(x: int) -> int:
+    return x + 1
+
+
+def _make_trace_context() -> Context:
+    """Create a context that simulates being inside a @trace within a task."""
+    task_ctx = TaskContext(
+        action=ActionID(name="test"),
+        run_base_dir="test",
+        output_path="test",
+        raw_data_path=RawDataPath(path=""),
+        version="",
+        report=Report("test"),
+    )
+    return Context(data=ContextData(task_context=task_ctx, in_trace=True))
+
+
+def test_task_call_raises_in_trace_context():
+    """Calling a @task via __call__ inside a flyte.trace context must raise."""
+    trace_ctx = _make_trace_context()
+    with trace_ctx:
+        with pytest.raises(TraceDoesNotAllowNestedTasksError):
+            sync_child_task(1)
+
+
+def test_async_task_call_raises_in_trace_context():
+    """Calling an async @task via __call__ inside a flyte.trace context must raise."""
+    trace_ctx = _make_trace_context()
+    with trace_ctx:
+        with pytest.raises(TraceDoesNotAllowNestedTasksError):
+            async_child_task(1)
+
+
+@pytest.mark.asyncio
+async def test_task_aio_raises_in_trace_context():
+    """Calling a @task via .aio() inside a flyte.trace context must raise."""
+    trace_ctx = _make_trace_context()
+    async with trace_ctx:
+        with pytest.raises(TraceDoesNotAllowNestedTasksError):
+            await async_child_task.aio(1)
+
+
+@pytest.mark.asyncio
+async def test_sync_task_aio_raises_in_trace_context():
+    """Calling a sync @task via .aio() inside a flyte.trace context must raise."""
+    trace_ctx = _make_trace_context()
+    async with trace_ctx:
+        with pytest.raises(TraceDoesNotAllowNestedTasksError):
+            await sync_child_task.aio(1)
+
+
+def test_task_forward_bypasses_trace_check():
+    """Calling task.forward() should work even in trace context (it's the escape hatch)."""
+    trace_ctx = _make_trace_context()
+    with trace_ctx:
+        assert sync_child_task.forward(1) == 2
diff --git a/tests/user_api/test_traces.py b/tests/user_api/test_traces.py
index 6d61fe4c0..6ce90cafe 100644
--- a/tests/user_api/test_traces.py
+++ b/tests/user_api/test_traces.py
@@ -107,3 +107,27 @@ async def traced_func() -> int:
     assert trace_action_name != parent_action_name, (
         f"Trace should have different action ID than parent. Trace: {trace_action_name}, Parent: {parent_action_name}"
     )
+
+
+@env.task
+async def nested_child_task(x: int) -> int:
+    return x * 2
+
+
+@flyte.trace
+async def trace_that_calls_task(x: int) -> int:
+    return await nested_child_task(x)
+
+
+@env.task
+async def parent_calling_trace_with_task(x: int = 5) -> int:
+    return await trace_that_calls_task(x)
+
+
+@pytest.mark.asyncio
+async def test_task_nested_in_trace_raises():
+    """Calling a @task inside a @flyte.trace must raise TraceDoesNotAllowNestedTasksError."""
+    await flyte.init.aio()
+    with pytest.raises(flyte.errors.RuntimeUserError) as excinfo:
+        flyte.run(parent_calling_trace_with_task, x=5)
+    assert excinfo.value.code == "TraceDoesNotAllowNestedTasksError"