|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | import flyte |
6 | | -from flyte._context import internal_ctx |
| 6 | +from flyte._context import Context, ContextData, internal_ctx |
| 7 | +from flyte.errors import TraceDoesNotAllowNestedTasksError |
7 | 8 | from flyte.models import ActionID, RawDataPath, TaskContext |
8 | 9 | from flyte.report import Report |
9 | 10 | from flyte.syncify import syncify |
@@ -217,3 +218,72 @@ def test_has_raw_data_priority(): |
217 | 218 | assert internal_ctx().has_raw_data is True |
218 | 219 | # Also verify that raw_data returns the task context path |
219 | 220 | assert internal_ctx().raw_data.path == "/task/path" |
| 221 | + |
| 222 | + |
| 223 | +# --- Tasks cannot be nested inside flyte.trace --- |
| 224 | + |
| 225 | +env = flyte.TaskEnvironment(name="test_trace_nesting") |
| 226 | + |
| 227 | + |
| 228 | +@env.task |
| 229 | +async def async_child_task(x: int) -> int: |
| 230 | + return x + 1 |
| 231 | + |
| 232 | + |
| 233 | +@env.task |
| 234 | +def sync_child_task(x: int) -> int: |
| 235 | + return x + 1 |
| 236 | + |
| 237 | + |
| 238 | +def _make_trace_context() -> Context: |
| 239 | + """Create a context that simulates being inside a @trace within a task.""" |
| 240 | + task_ctx = TaskContext( |
| 241 | + action=ActionID(name="test"), |
| 242 | + run_base_dir="test", |
| 243 | + output_path="test", |
| 244 | + raw_data_path=RawDataPath(path=""), |
| 245 | + version="", |
| 246 | + report=Report("test"), |
| 247 | + ) |
| 248 | + return Context(data=ContextData(task_context=task_ctx, in_trace=True)) |
| 249 | + |
| 250 | + |
| 251 | +def test_task_call_raises_in_trace_context(): |
| 252 | + """Calling a @task via __call__ inside a flyte.trace context must raise.""" |
| 253 | + trace_ctx = _make_trace_context() |
| 254 | + with trace_ctx: |
| 255 | + with pytest.raises(TraceDoesNotAllowNestedTasksError): |
| 256 | + sync_child_task(1) |
| 257 | + |
| 258 | + |
| 259 | +def test_async_task_call_raises_in_trace_context(): |
| 260 | + """Calling an async @task via __call__ inside a flyte.trace context must raise.""" |
| 261 | + trace_ctx = _make_trace_context() |
| 262 | + with trace_ctx: |
| 263 | + with pytest.raises(TraceDoesNotAllowNestedTasksError): |
| 264 | + async_child_task(1) |
| 265 | + |
| 266 | + |
| 267 | +@pytest.mark.asyncio |
| 268 | +async def test_task_aio_raises_in_trace_context(): |
| 269 | + """Calling a @task via .aio() inside a flyte.trace context must raise.""" |
| 270 | + trace_ctx = _make_trace_context() |
| 271 | + async with trace_ctx: |
| 272 | + with pytest.raises(TraceDoesNotAllowNestedTasksError): |
| 273 | + await async_child_task.aio(1) |
| 274 | + |
| 275 | + |
| 276 | +@pytest.mark.asyncio |
| 277 | +async def test_sync_task_aio_raises_in_trace_context(): |
| 278 | + """Calling a sync @task via .aio() inside a flyte.trace context must raise.""" |
| 279 | + trace_ctx = _make_trace_context() |
| 280 | + async with trace_ctx: |
| 281 | + with pytest.raises(TraceDoesNotAllowNestedTasksError): |
| 282 | + await sync_child_task.aio(1) |
| 283 | + |
| 284 | + |
| 285 | +def test_task_forward_bypasses_trace_check(): |
| 286 | + """Calling task.forward() should work even in trace context (it's the escape hatch).""" |
| 287 | + trace_ctx = _make_trace_context() |
| 288 | + with trace_ctx: |
| 289 | + assert sync_child_task.forward(1) == 2 |
0 commit comments