Skip to content

Commit 07dc8c7

Browse files
samhita-allakumare3claude
authored
Tasks cannot be nested within a trace! (#636)
`task_context` is preserved inside a trace but any task called from within a trace will run as plain python. --------- Signed-off-by: Samhita Alla <aallasamhita@gmail.com> Signed-off-by: Ketan Umare <kumare3@users.noreply.github.com> Co-authored-by: Ketan Umare <kumare3@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 05ab2d1 commit 07dc8c7

File tree

6 files changed

+136
-8
lines changed

6 files changed

+136
-8
lines changed

src/flyte/_context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ContextData:
3232
metadata: Optional[Tuple[Tuple[str, str], ...]] = None
3333
preserve_original_types: bool = False
3434
tracker: Any = None # ActionTracker instance (optional, set for TUI runs)
35+
in_trace: bool = False # True when executing inside a @trace decorated function
3536

3637
def replace(self, **kwargs) -> ContextData:
3738
return replace(self, **kwargs)
@@ -124,6 +125,13 @@ def is_task_context(self) -> bool:
124125
"""
125126
return self.data.task_context is not None
126127

128+
def is_in_trace(self) -> bool:
129+
"""
130+
Returns true if the context is in a trace context, else False
131+
Returns: bool
132+
"""
133+
return self.data.in_trace
134+
127135
def __enter__(self):
128136
"""Enter the context, setting it as the current context."""
129137
self._token = root_context_var.set(self)

src/flyte/_task.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525

2626
from flyte._pod import PodTemplate
27-
from flyte.errors import RuntimeSystemError, RuntimeUserError
27+
from flyte.errors import RuntimeSystemError, RuntimeUserError, TraceDoesNotAllowNestedTasksError
2828

2929
from ._cache import Cache, CacheRequest
3030
from ._context import internal_ctx
@@ -264,6 +264,12 @@ async def my_new_parent_task(n: int) -> List[int]:
264264
"""
265265
ctx = internal_ctx()
266266
if ctx.is_task_context():
267+
if ctx.is_in_trace():
268+
raise TraceDoesNotAllowNestedTasksError(
269+
f"Task {self.name} is invoked from inside a `flyte.trace`. "
270+
"You can continue using the task function, as a regular"
271+
"python function using `task`.forward(...) facade."
272+
)
267273
from ._internal.controllers import get_controller
268274

269275
# If we are in a task context, that implies we are executing a Run.
@@ -307,6 +313,12 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R]
307313
try:
308314
ctx = internal_ctx()
309315
if ctx.is_task_context():
316+
if ctx.is_in_trace():
317+
raise TraceDoesNotAllowNestedTasksError(
318+
f"Task {self.name} is invoked from inside a `flyte.trace`. "
319+
"You can continue using the task function, as a regular"
320+
"python function using `task`.forward(...) facade."
321+
)
310322
# If we are in a task context, that implies we are executing a Run.
311323
# In this scenario, we should submit the task to the controller.
312324
# We will also check if we are not initialized, It is not expected to be not initialized

src/flyte/_trace.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def wrapper_sync(*args: Any, **kwargs: Any) -> Any:
3131

3232
@functools.wraps(func)
3333
async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
34-
from flyte._context import internal_ctx
34+
from flyte._context import Context, internal_ctx
3535

3636
ctx = internal_ctx()
3737
if ctx.is_task_context():
@@ -53,10 +53,12 @@ async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
5353
logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
5454
start_time = time.time()
5555

56-
# Create a new context with the trace's action ID
56+
# Create a new context with the trace's action ID and mark as in_trace
57+
# so that nested task calls run as pure Python instead of submitting to the controller.
5758
# Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
5859
trace_task_context = ctx.data.task_context.replace(action=info.action) # type: ignore[union-attr]
59-
trace_context = ctx.replace_task_context(trace_task_context)
60+
trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
61+
trace_context = Context(trace_data)
6062

6163
# Execute function in trace context, then record outside it
6264
error = None
@@ -90,7 +92,7 @@ def is_async_iterable(obj: Any) -> TypeGuard[Union[AsyncGenerator, AsyncIterator
9092

9193
@functools.wraps(func)
9294
async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
93-
from flyte._context import internal_ctx
95+
from flyte._context import Context, internal_ctx
9496

9597
ctx = internal_ctx()
9698
if ctx.is_task_context():
@@ -110,10 +112,12 @@ async def wrapper_async_iterator(*args: Any, **kwargs: Any) -> AsyncIterator[Any
110112
raise info.error
111113
start_time = time.time()
112114

113-
# Create a new context with the trace's action ID
115+
# Create a new context with the trace's action ID and mark as in_trace
116+
# so that nested task calls run as pure Python instead of submitting to the controller.
114117
# Note: ctx.data.task_context is guaranteed to be non-None by is_task_context() check above
115118
trace_task_context = ctx.data.task_context.replace(action=info.action) # type: ignore[union-attr]
116-
trace_context = ctx.replace_task_context(trace_task_context)
119+
trace_data = ctx.data.replace(task_context=trace_task_context, in_trace=True)
120+
trace_context = Context(trace_data)
117121

118122
# Execute function in trace context, then record outside it
119123
error = None

src/flyte/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,13 @@ class CodeBundleError(RuntimeUserError):
283283

284284
def __init__(self, message: str):
285285
super().__init__("CodeBundleError", message, "user")
286+
287+
288+
class TraceDoesNotAllowNestedTasksError(RuntimeUserError):
289+
"""
290+
This error is raised when the user tries to use a task from within a trace. Tasks can be nested under tasks
291+
not traces.
292+
"""
293+
294+
def __init__(self, message: str):
295+
super().__init__("TraceDoesNotAllowNestedTasksError", message)

tests/internal/test_context.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import pytest
44

55
import flyte
6-
from flyte._context import internal_ctx
6+
from flyte._context import Context, ContextData, internal_ctx
7+
from flyte.errors import TraceDoesNotAllowNestedTasksError
78
from flyte.models import ActionID, RawDataPath, TaskContext
89
from flyte.report import Report
910
from flyte.syncify import syncify
@@ -217,3 +218,72 @@ def test_has_raw_data_priority():
217218
assert internal_ctx().has_raw_data is True
218219
# Also verify that raw_data returns the task context path
219220
assert internal_ctx().raw_data.path == "/task/path"
221+
222+
223+
# --- Tasks cannot be nested inside flyte.trace ---
224+
225+
env = flyte.TaskEnvironment(name="test_trace_nesting")
226+
227+
228+
@env.task
229+
async def async_child_task(x: int) -> int:
230+
return x + 1
231+
232+
233+
@env.task
234+
def sync_child_task(x: int) -> int:
235+
return x + 1
236+
237+
238+
def _make_trace_context() -> Context:
239+
"""Create a context that simulates being inside a @trace within a task."""
240+
task_ctx = TaskContext(
241+
action=ActionID(name="test"),
242+
run_base_dir="test",
243+
output_path="test",
244+
raw_data_path=RawDataPath(path=""),
245+
version="",
246+
report=Report("test"),
247+
)
248+
return Context(data=ContextData(task_context=task_ctx, in_trace=True))
249+
250+
251+
def test_task_call_raises_in_trace_context():
252+
"""Calling a @task via __call__ inside a flyte.trace context must raise."""
253+
trace_ctx = _make_trace_context()
254+
with trace_ctx:
255+
with pytest.raises(TraceDoesNotAllowNestedTasksError):
256+
sync_child_task(1)
257+
258+
259+
def test_async_task_call_raises_in_trace_context():
260+
"""Calling an async @task via __call__ inside a flyte.trace context must raise."""
261+
trace_ctx = _make_trace_context()
262+
with trace_ctx:
263+
with pytest.raises(TraceDoesNotAllowNestedTasksError):
264+
async_child_task(1)
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_task_aio_raises_in_trace_context():
269+
"""Calling a @task via .aio() inside a flyte.trace context must raise."""
270+
trace_ctx = _make_trace_context()
271+
async with trace_ctx:
272+
with pytest.raises(TraceDoesNotAllowNestedTasksError):
273+
await async_child_task.aio(1)
274+
275+
276+
@pytest.mark.asyncio
277+
async def test_sync_task_aio_raises_in_trace_context():
278+
"""Calling a sync @task via .aio() inside a flyte.trace context must raise."""
279+
trace_ctx = _make_trace_context()
280+
async with trace_ctx:
281+
with pytest.raises(TraceDoesNotAllowNestedTasksError):
282+
await sync_child_task.aio(1)
283+
284+
285+
def test_task_forward_bypasses_trace_check():
286+
"""Calling task.forward() should work even in trace context (it's the escape hatch)."""
287+
trace_ctx = _make_trace_context()
288+
with trace_ctx:
289+
assert sync_child_task.forward(1) == 2

tests/user_api/test_traces.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,27 @@ async def traced_func() -> int:
107107
assert trace_action_name != parent_action_name, (
108108
f"Trace should have different action ID than parent. Trace: {trace_action_name}, Parent: {parent_action_name}"
109109
)
110+
111+
112+
@env.task
113+
async def nested_child_task(x: int) -> int:
114+
return x * 2
115+
116+
117+
@flyte.trace
118+
async def trace_that_calls_task(x: int) -> int:
119+
return await nested_child_task(x)
120+
121+
122+
@env.task
123+
async def parent_calling_trace_with_task(x: int = 5) -> int:
124+
return await trace_that_calls_task(x)
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_task_nested_in_trace_raises():
129+
"""Calling a @task inside a @flyte.trace must raise TraceDoesNotAllowNestedTasksError."""
130+
await flyte.init.aio()
131+
with pytest.raises(flyte.errors.RuntimeUserError) as excinfo:
132+
flyte.run(parent_calling_trace_with_task, x=5)
133+
assert excinfo.value.code == "TraceDoesNotAllowNestedTasksError"

0 commit comments

Comments
 (0)