Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/flyte/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ContextData:
metadata: Optional[Tuple[Tuple[str, str], ...]] = None
preserve_original_types: bool = False
tracker: Any = None # ActionTracker instance (optional, set for TUI runs)
in_trace: bool = False # True when executing inside a @trace decorated function

def replace(self, **kwargs) -> ContextData:
return replace(self, **kwargs)
Expand Down Expand Up @@ -124,6 +125,13 @@ def is_task_context(self) -> bool:
"""
return self.data.task_context is not None

def is_in_trace(self) -> bool:
"""
Returns true if the context is in a trace context, else False
Returns: bool
"""
return self.data.in_trace

def __enter__(self):
"""Enter the context, setting it as the current context."""
self._token = root_context_var.set(self)
Expand Down
14 changes: 13 additions & 1 deletion src/flyte/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

from flyte._pod import PodTemplate
from flyte.errors import RuntimeSystemError, RuntimeUserError
from flyte.errors import RuntimeSystemError, RuntimeUserError, TraceDoesNotAllowNestedTasksError

from ._cache import Cache, CacheRequest
from ._context import internal_ctx
Expand Down Expand Up @@ -264,6 +264,12 @@ async def my_new_parent_task(n: int) -> List[int]:
"""
ctx = internal_ctx()
if ctx.is_task_context():
if ctx.is_in_trace():
raise TraceDoesNotAllowNestedTasksError(
f"Task {self.name} is invoked from inside a `flyte.trace`. "
"You can continue using the task function, as a regular"
"python function using `task`.forward(...) facade."
)
from ._internal.controllers import get_controller

# If we are in a task context, that implies we are executing a Run.
Expand Down Expand Up @@ -307,6 +313,12 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R]
try:
ctx = internal_ctx()
if ctx.is_task_context():
if ctx.is_in_trace():
raise TraceDoesNotAllowNestedTasksError(
f"Task {self.name} is invoked from inside a `flyte.trace`. "
"You can continue using the task function, as a regular"
"python function using `task`.forward(...) facade."
)
# If we are in a task context, that implies we are executing a Run.
# In this scenario, we should submit the task to the controller.
# We will also check if we are not initialized, It is not expected to be not initialized
Expand Down
16 changes: 10 additions & 6 deletions src/flyte/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def wrapper_sync(*args: Any, **kwargs: Any) -> Any:

@functools.wraps(func)
async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
from flyte._context import internal_ctx
from flyte._context import Context, internal_ctx

ctx = internal_ctx()
if ctx.is_task_context():
Expand All @@ -53,10 +53,12 @@ async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
start_time = time.time()

# Create a new context with the trace's action ID
# Create a new context with the trace's action ID and mark as in_trace
# so that nested task calls run as pure Python instead of submitting to the controller.
# Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
trace_task_context = ctx.data.task_context.replace(action=info.action) # type: ignore[union-attr]
trace_context = ctx.replace_task_context(trace_task_context)
trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
trace_context = Context(trace_data)

# Execute function in trace context, then record outside it
error = None
Expand Down Expand Up @@ -90,7 +92,7 @@ def is_async_iterable(obj: Any) -> TypeGuard[Union[AsyncGenerator, AsyncIterator

@functools.wraps(func)
async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
from flyte._context import internal_ctx
from flyte._context import Context, internal_ctx

ctx = internal_ctx()
if ctx.is_task_context():
Expand All @@ -110,10 +112,12 @@ async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any
raise info.error
start_time = time.time()

# Create a new context with the trace's action ID
# Create a new context with the trace's action ID and mark as in_trace
# so that nested task calls run as pure Python instead of submitting to the controller.
# Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
trace_task_context = ctx.data.task_context.replace(action=info.action) # type: ignore[union-attr]
trace_context = ctx.replace_task_context(trace_task_context)
trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
trace_context = Context(trace_data)

# Execute function in trace context, then record outside it
error = None
Expand Down
10 changes: 10 additions & 0 deletions src/flyte/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,13 @@ class CodeBundleError(RuntimeUserError):

def __init__(self, message: str):
super().__init__("CodeBundleError", message, "user")


class TraceDoesNotAllowNestedTasksError(RuntimeUserError):
"""
This error is raised when the user tries to use a task from within a trace. Tasks can be nested under tasks
not traces.
"""

def __init__(self, message: str):
super().__init__("TraceDoesNotAllowNestedTasksError", message)
72 changes: 71 additions & 1 deletion tests/internal/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pytest

import flyte
from flyte._context import internal_ctx
from flyte._context import Context, ContextData, internal_ctx
from flyte.errors import TraceDoesNotAllowNestedTasksError
from flyte.models import ActionID, RawDataPath, TaskContext
from flyte.report import Report
from flyte.syncify import syncify
Expand Down Expand Up @@ -217,3 +218,72 @@ def test_has_raw_data_priority():
assert internal_ctx().has_raw_data is True
# Also verify that raw_data returns the task context path
assert internal_ctx().raw_data.path == "/task/path"


# --- Tasks cannot be nested inside flyte.trace ---

env = flyte.TaskEnvironment(name="test_trace_nesting")


@env.task
async def async_child_task(x: int) -> int:
return x + 1


@env.task
def sync_child_task(x: int) -> int:
return x + 1


def _make_trace_context() -> Context:
"""Create a context that simulates being inside a @trace within a task."""
task_ctx = TaskContext(
action=ActionID(name="test"),
run_base_dir="test",
output_path="test",
raw_data_path=RawDataPath(path=""),
version="",
report=Report("test"),
)
return Context(data=ContextData(task_context=task_ctx, in_trace=True))


def test_task_call_raises_in_trace_context():
"""Calling a @task via __call__ inside a flyte.trace context must raise."""
trace_ctx = _make_trace_context()
with trace_ctx:
with pytest.raises(TraceDoesNotAllowNestedTasksError):
sync_child_task(1)


def test_async_task_call_raises_in_trace_context():
"""Calling an async @task via __call__ inside a flyte.trace context must raise."""
trace_ctx = _make_trace_context()
with trace_ctx:
with pytest.raises(TraceDoesNotAllowNestedTasksError):
async_child_task(1)


@pytest.mark.asyncio
async def test_task_aio_raises_in_trace_context():
"""Calling a @task via .aio() inside a flyte.trace context must raise."""
trace_ctx = _make_trace_context()
async with trace_ctx:
with pytest.raises(TraceDoesNotAllowNestedTasksError):
await async_child_task.aio(1)


@pytest.mark.asyncio
async def test_sync_task_aio_raises_in_trace_context():
"""Calling a sync @task via .aio() inside a flyte.trace context must raise."""
trace_ctx = _make_trace_context()
async with trace_ctx:
with pytest.raises(TraceDoesNotAllowNestedTasksError):
await sync_child_task.aio(1)


def test_task_forward_bypasses_trace_check():
"""Calling task.forward() should work even in trace context (it's the escape hatch)."""
trace_ctx = _make_trace_context()
with trace_ctx:
assert sync_child_task.forward(1) == 2
24 changes: 24 additions & 0 deletions tests/user_api/test_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,27 @@ async def traced_func() -> int:
assert trace_action_name != parent_action_name, (
f"Trace should have different action ID than parent. Trace: {trace_action_name}, Parent: {parent_action_name}"
)


@env.task
async def nested_child_task(x: int) -> int:
return x * 2


@flyte.trace
async def trace_that_calls_task(x: int) -> int:
return await nested_child_task(x)


@env.task
async def parent_calling_trace_with_task(x: int = 5) -> int:
return await trace_that_calls_task(x)


@pytest.mark.asyncio
async def test_task_nested_in_trace_raises():
"""Calling a @task inside a @flyte.trace must raise TraceDoesNotAllowNestedTasksError."""
await flyte.init.aio()
with pytest.raises(flyte.errors.RuntimeUserError) as excinfo:
flyte.run(parent_calling_trace_with_task, x=5)
assert excinfo.value.code == "TraceDoesNotAllowNestedTasksError"
Loading