Commit 20916be

Authored by iamemilio, mergify[bot], and cdoern
fix: prevent OTel context leak in fire-and-forget background tasks (#5168)
## What's the problem?

When you look at a trace in Jaeger, you expect it to show what happened during a single request. Instead, during load testing we found traces that looked like this:

- A request that took **5 seconds** showed a trace lasting **62 seconds**
- That trace contained **2,594 spans**, including **334 database writes that belonged to completely different requests**

The trace was essentially garbage -- you couldn't tell what actually happened during the request vs. what leaked in from other requests happening at the same time.

## Why does this happen?

The server uses background worker tasks to write data to the database without blocking the API response. These workers are long-lived -- they start up once and process a shared queue forever.

The problem is how Python's `asyncio.create_task` works: it copies all context variables (including the OpenTelemetry trace context) at the moment the task is created. So whichever API request happens to **first** trigger worker creation permanently stamps its trace ID onto that worker. Every database write the worker processes from that point forward -- regardless of which request it came from -- gets attributed to that original request's trace.

```
Request A arrives → spawns worker → worker inherits trace A
Request B arrives → enqueues work → worker processes it under trace A ← wrong!
Request C arrives → enqueues work → worker processes it under trace A ← wrong!
...forever
```

## How does this fix it?

Two changes working together:

**1. Workers start with a clean slate.** A new helper (`create_task_with_detached_otel_context`) creates the worker task with an empty trace context, so it doesn't permanently inherit any request's identity.

**2. Each queue item carries its own trace context.** When a request enqueues work, it snapshots its current trace context and attaches it to the queue item.
When the worker picks up that item, it temporarily activates the captured context for the duration of that work, then returns to a clean state before processing the next item.

```
Request A arrives → enqueues work with trace A context
Request B arrives → enqueues work with trace B context
Worker (no trace) → picks up item A → activates trace A → writes to DB → deactivates
                  → picks up item B → activates trace B → writes to DB → deactivates
```

The result: each database write shows up under the correct request's trace. No inflation, no cross-contamination.

## What changed?

| File | What it does |
|------|--------------|
| `core/task.py` (new) | Three utilities: `create_task_with_detached_otel_context` (start tasks clean), `capture_otel_context` (snapshot current context), `activate_otel_context` (temporarily restore a captured context) |
| `inference_store.py` | Queue items now carry the OTel context; workers activate it per-item before writing |
| `openai_responses.py` | Same pattern for the responses background worker |
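The inheritance bug at the heart of this change is easy to reproduce with a plain `contextvars.ContextVar` standing in for the OTel trace context. This is a minimal stdlib-only sketch with hypothetical names, not code from this PR:

```python
import asyncio
import contextvars

# Stand-in for the OTel trace context.
trace_id: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="none")
results: list[str] = []

async def worker(queue: asyncio.Queue) -> None:
    while True:
        payload = await queue.get()
        # The worker reports the trace of whichever request *created* it,
        # not the trace of the request that enqueued this item.
        results.append(f"{payload} handled under trace {trace_id.get()}")
        queue.task_done()

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    trace_id.set("request-A")                # request A is active...
    wt = asyncio.create_task(worker(queue))  # ...so the worker copies A's context now
    trace_id.set("request-B")                # request B enqueues work later
    await queue.put("item-from-B")
    await queue.join()
    wt.cancel()

asyncio.run(main())
print(results[0])  # item-from-B handled under trace request-A  <- wrong!
```

Because `asyncio.create_task` snapshots the context at creation time, later `set()` calls in the parent never reach the worker's copy — exactly the permanent mis-attribution described above.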
## How is this tested?

**14 new tests** across three files:

- **`test_task.py`** (9 tests) -- validates the primitives: detached tasks get a clean context, captured context can be re-activated, context flows correctly through a queue, and two requests don't contaminate each other
- **`test_inference_store.py`** (2 tests) -- end-to-end with a real SQLite-backed InferenceStore: simulates two API requests, lets the queue + workers process the writes, and asserts each write lands in the correct trace (this directly reproduces the original bug)
- **`test_responses_background.py`** (3 tests) -- same validation for the responses worker, plus a test proving that error-handling DB writes (marking a response as failed) are also attributed to the correct trace

## Test plan

- [x] All 14 new unit tests pass
- [x] All existing unit tests unaffected
- [x] Inference and Responses API tests that use in-memory OTel span collectors pass

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Charlie Doern <cdoern@redhat.com>
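The no-cross-contamination property these tests assert can be sketched stdlib-only, modeling capture/activate with `contextvars.copy_context()` and `Context.run` (hypothetical names; the real tests use the OTel API):

```python
import asyncio
import contextvars

trace_id: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="none")
results: list[str] = []

async def worker(queue: asyncio.Queue) -> None:
    while True:
        ctx, payload = await queue.get()
        # Run the "write" under the captured snapshot, then fall back to the
        # worker's own clean context before the next item.
        ctx.run(lambda: results.append(f"{payload} under trace {trace_id.get()}"))
        queue.task_done()

async def handle_request(queue: asyncio.Queue, req: str) -> None:
    trace_id.set(req)  # this request's trace is active in this task's context
    await queue.put((contextvars.copy_context(), f"item-from-{req}"))

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    wt = asyncio.create_task(worker(queue))
    # Two independent "requests", each in its own task (and thus own context).
    await asyncio.gather(
        asyncio.create_task(handle_request(queue, "request-A")),
        asyncio.create_task(handle_request(queue, "request-B")),
    )
    await queue.join()
    wt.cancel()

asyncio.run(main())
print(sorted(results))  # each item attributed to its own request
```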
1 parent (1bdf971) · commit 20916be

File tree: 6 files changed, +774 −75 lines changed

### `src/llama_stack/core/task.py`

Lines changed: 50 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from collections.abc import Coroutine
from contextlib import contextmanager
from typing import Any

from opentelemetry import context as otel_context


def create_task_with_detached_otel_context(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
    """Create an asyncio task that does not inherit the current OpenTelemetry trace context.

    asyncio.create_task copies all contextvars at creation time, which causes
    fire-and-forget or long-lived background tasks to be attributed to whatever
    request happened to spawn them. This inflates trace durations and bundles
    unrelated DB operations under the wrong trace.

    This helper temporarily clears the OTel context before creating the task,
    then immediately restores it so the calling coroutine is unaffected.
    """
    token = otel_context.attach(otel_context.Context())
    try:
        task = asyncio.create_task(coro)
    finally:
        otel_context.detach(token)
    return task


def capture_otel_context() -> otel_context.Context:
    """Snapshot the current OTel context for later use in a different task."""
    return otel_context.get_current()


@contextmanager
def activate_otel_context(ctx: otel_context.Context):
    """Temporarily activate a previously captured OTel context.

    Use this in worker loops that run with a detached (empty) context to
    attribute work back to the originating request's trace.
    """
    token = otel_context.attach(ctx)
    try:
        yield
    finally:
        otel_context.detach(token)
```
### `src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py`

Lines changed: 82 additions & 67 deletions
```diff
@@ -9,10 +9,13 @@
 import time
 import uuid
 from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
 
+from opentelemetry import context as otel_context
 from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.core.conversations.validation import CONVERSATION_ID_PATTERN
+from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.responses.responses_store import (
     ResponsesStore,
@@ -82,6 +85,14 @@
 BACKGROUND_NUM_WORKERS = 10
 
 
+@dataclass
+class _BackgroundWorkItem:
+    """Typed queue item for background response processing."""
+
+    otel_context: otel_context.Context
+    kwargs: dict = field(default_factory=dict)
+
+
 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
     input_items: ListOpenAIResponseInputItem
     response: OpenAIResponseObject
@@ -118,7 +129,7 @@ def __init__(
         self.prompts_api = prompts_api
         self.files_api = files_api
         self.connectors_api = connectors_api
-        self._background_queue: asyncio.Queue = asyncio.Queue(maxsize=BACKGROUND_QUEUE_MAX_SIZE)
+        self._background_queue: asyncio.Queue[_BackgroundWorkItem] = asyncio.Queue(maxsize=BACKGROUND_QUEUE_MAX_SIZE)
         self._background_worker_tasks: set[asyncio.Task] = set()
 
     async def initialize(self) -> None:
@@ -133,7 +144,7 @@ async def initialize(self) -> None:
     async def _ensure_workers_started(self) -> None:
         """Start background workers in the current event loop if not already running."""
         for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)):
-            task = asyncio.create_task(self._background_worker())
+            task = create_task_with_detached_otel_context(self._background_worker())
             self._background_worker_tasks.add(task)
             task.add_done_callback(self._background_worker_tasks.discard)
 
@@ -146,48 +157,49 @@ async def shutdown(self) -> None:
     async def _background_worker(self) -> None:
         """Worker coroutine that pulls items from the queue and processes them."""
         while True:
-            kwargs = await self._background_queue.get()
-            try:
-                await asyncio.wait_for(
-                    self._run_background_response_loop(**kwargs),
-                    timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS,
-                )
-            except TimeoutError:
-                response_id = kwargs["response_id"]
-                logger.exception(
-                    f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s"
-                )
+            item = await self._background_queue.get()
+            with activate_otel_context(item.otel_context):
                 try:
-                    existing = await self.responses_store.get_response_object(response_id)
-                    existing.status = "failed"
-                    existing.error = OpenAIResponseError(
-                        code="processing_error",
-                        message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s",
+                    await asyncio.wait_for(
+                        self._run_background_response_loop(**item.kwargs),
+                        timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS,
                     )
-                    await self.responses_store.update_response_object(existing)
-                except Exception:
+                except TimeoutError:
+                    response_id = item.kwargs["response_id"]
                     logger.exception(
-                        f"Failed to update response {response_id} with timeout status. "
-                        "Client polling this response will not see the failure."
-                    )
-            except Exception as e:
-                response_id = kwargs["response_id"]
-                logger.exception(f"Error processing background response {response_id}")
-                try:
-                    existing = await self.responses_store.get_response_object(response_id)
-                    existing.status = "failed"
-                    existing.error = OpenAIResponseError(
-                        code="processing_error",
-                        message=str(e),
+                        f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s"
                     )
-                    await self.responses_store.update_response_object(existing)
-                except Exception:
-                    logger.exception(
-                        f"Failed to update response {response_id} with error status. "
-                        "Client polling this response will not see the failure."
-                    )
-            finally:
-                self._background_queue.task_done()
+                    try:
+                        existing = await self.responses_store.get_response_object(response_id)
+                        existing.status = "failed"
+                        existing.error = OpenAIResponseError(
+                            code="processing_error",
+                            message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s",
+                        )
+                        await self.responses_store.update_response_object(existing)
+                    except Exception:
+                        logger.exception(
+                            f"Failed to update response {response_id} with timeout status. "
+                            "Client polling this response will not see the failure."
+                        )
+                except Exception as e:
+                    response_id = item.kwargs["response_id"]
+                    logger.exception(f"Error processing background response {response_id}")
+                    try:
+                        existing = await self.responses_store.get_response_object(response_id)
+                        existing.status = "failed"
+                        existing.error = OpenAIResponseError(
+                            code="processing_error",
+                            message=str(e),
+                        )
+                        await self.responses_store.update_response_object(existing)
+                    except Exception:
+                        logger.exception(
+                            f"Failed to update response {response_id} with error status. "
+                            "Client polling this response will not see the failure."
+                        )
+                finally:
+                    self._background_queue.task_done()
 
     async def _prepend_previous_response(
         self,
@@ -820,33 +832,36 @@ async def _create_background_response(
         # Enqueue work item for background workers. Raises QueueFull if at capacity.
         try:
             self._background_queue.put_nowait(
-                dict(
-                    response_id=response_id,
-                    input=input,
-                    model=model,
-                    prompt=prompt,
-                    instructions=instructions,
-                    previous_response_id=previous_response_id,
-                    conversation=conversation,
-                    store=store,
-                    temperature=temperature,
-                    frequency_penalty=frequency_penalty,
-                    text=text,
-                    tool_choice=tool_choice,
-                    tools=tools,
-                    include=include,
-                    max_infer_iters=max_infer_iters,
-                    guardrail_ids=guardrail_ids,
-                    parallel_tool_calls=parallel_tool_calls,
-                    max_tool_calls=max_tool_calls,
-                    reasoning=reasoning,
-                    max_output_tokens=max_output_tokens,
-                    safety_identifier=safety_identifier,
-                    service_tier=service_tier,
-                    metadata=metadata,
-                    truncation=truncation,
-                    presence_penalty=presence_penalty,
-                    extra_body=extra_body,
+                _BackgroundWorkItem(
+                    otel_context=capture_otel_context(),
+                    kwargs=dict(
+                        response_id=response_id,
+                        input=input,
+                        model=model,
+                        prompt=prompt,
+                        instructions=instructions,
+                        previous_response_id=previous_response_id,
+                        conversation=conversation,
+                        store=store,
+                        temperature=temperature,
+                        frequency_penalty=frequency_penalty,
+                        text=text,
+                        tool_choice=tool_choice,
+                        tools=tools,
+                        include=include,
+                        max_infer_iters=max_infer_iters,
+                        guardrail_ids=guardrail_ids,
+                        parallel_tool_calls=parallel_tool_calls,
+                        max_tool_calls=max_tool_calls,
+                        reasoning=reasoning,
+                        max_output_tokens=max_output_tokens,
+                        safety_identifier=safety_identifier,
+                        service_tier=service_tier,
+                        metadata=metadata,
+                        truncation=truncation,
+                        presence_penalty=presence_penalty,
+                        extra_body=extra_body,
+                    ),
                 )
             )
         except asyncio.QueueFull:
```
### `src/llama_stack/providers/utils/inference/inference_store.py`

Lines changed: 16 additions & 8 deletions
```diff
@@ -4,14 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import asyncio
-from typing import Any
+from typing import Any, NamedTuple
 
+from opentelemetry import context as otel_context
 from sqlalchemy.exc import IntegrityError
 
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
 from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
+from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
 from llama_stack.log import get_logger
 from llama_stack_api import (
     ListOpenAIChatCompletionResponse,
@@ -25,6 +27,12 @@
 logger = get_logger(name=__name__, category="inference")
 
 
+class _WriteItem(NamedTuple):
+    completion: OpenAIChatCompletion
+    messages: list[OpenAIMessageParam]
+    otel_context: otel_context.Context
+
+
 class InferenceStore:
     def __init__(
         self,
@@ -37,7 +45,7 @@ def __init__(
         self.enable_write_queue = True
 
         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._queue: asyncio.Queue[_WriteItem] | None = None
         self._worker_tasks: list[asyncio.Task[Any]] = []
         self._max_write_queue_size: int = reference.max_write_queue_size
         self._num_writers: int = max(1, reference.num_writers)
@@ -98,9 +106,8 @@ async def _ensure_workers_started(self) -> None:
             )
 
         if not self._worker_tasks:
-            loop = asyncio.get_running_loop()
             for _ in range(self._num_writers):
-                task = loop.create_task(self._worker_loop())
+                task = create_task_with_detached_otel_context(self._worker_loop())
                 self._worker_tasks.append(task)
 
     async def store_chat_completion(
@@ -110,13 +117,14 @@
         await self._ensure_workers_started()
         if self._queue is None:
             raise ValueError("Inference store is not initialized")
+        item = _WriteItem(chat_completion, input_messages, capture_otel_context())
         try:
-            self._queue.put_nowait((chat_completion, input_messages))
+            self._queue.put_nowait(item)
         except asyncio.QueueFull:
             logger.warning(
                 f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
             )
-            await self._queue.put((chat_completion, input_messages))
+            await self._queue.put(item)
         else:
             await self._write_chat_completion(chat_completion, input_messages)
 
@@ -127,9 +135,9 @@ async def _worker_loop(self) -> None:
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            chat_completion, input_messages = item
             try:
-                await self._write_chat_completion(chat_completion, input_messages)
+                with activate_otel_context(item.otel_context):
+                    await self._write_chat_completion(item.completion, item.messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing chat completion: {e}")
             finally:
```
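Note that the worker wraps only the write in `activate_otel_context`, whose detach lives in a `finally` block: the worker returns to a clean context even when the wrapped write raises. A stdlib sketch of that guarantee (illustrative names, modeling the OTel context with a `ContextVar`):

```python
import contextvars
from contextlib import contextmanager

trace_id: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="none")

@contextmanager
def activate(snapshot: str):
    token = trace_id.set(snapshot)   # like otel_context.attach(ctx)
    try:
        yield
    finally:
        trace_id.reset(token)        # like otel_context.detach(token)

try:
    with activate("request-A"):
        assert trace_id.get() == "request-A"
        raise RuntimeError("simulated DB failure")
except RuntimeError:
    pass

print(trace_id.get())  # none  <- restored despite the exception
```

Without the `finally`, one failed write would leave the worker permanently stamped with that request's trace — recreating the original bug through the error path.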

0 commit comments