Skip to content

Commit c263a99

Browse files
authored
feat: propagate OTEL trace context across E/P/D multimodal workers (#7239)
1 parent 34f13a1 commit c263a99

File tree

5 files changed

+50
-16
lines changed

5 files changed

+50
-16
lines changed

components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
NixlWriteEmbeddingReceiver,
2222
)
2323
from dynamo.common.utils import nvtx_utils as _nvtx
24+
from dynamo.common.utils.otel_tracing import build_trace_headers
2425
from dynamo.common.utils.time_section import time_and_log_code_section
2526
from dynamo.runtime import Client, DistributedRuntime
2627

@@ -156,7 +157,7 @@ def _parse_frontend_request(
156157
# ── Multimodal data loading ──────────────────────────────────────
157158

158159
async def _load_multimodal_data(
159-
self, image_urls: list[str], request_id: str
160+
self, image_urls: list[str], request_id: str, context=None
160161
) -> dict[str, Any]:
161162
"""Fetch embeddings from encode workers and load into an engine-ready dict.
162163
@@ -174,6 +175,7 @@ async def _load_multimodal_data(
174175
model=self.config.model,
175176
embeddings_dtype=self.EMBEDDINGS_DTYPE,
176177
cache=self.embedding_cache_manager,
178+
context=context,
177179
)
178180

179181
# ── Request metadata finalization ────────────────────────────────
@@ -260,9 +262,11 @@ async def _generate_agg(
260262
request: vLLMMultimodalRequest,
261263
multi_modal_data: dict[str, Any],
262264
rng_ttft=None,
265+
context=None,
263266
):
264267
"""Run prefill and decode on this worker (aggregated mode)."""
265268
lora_request = self._resolve_lora_request(request.model)
269+
trace_headers = build_trace_headers(context) if context else None
266270
gen = self.engine_client.generate(
267271
prompt=TokensPrompt(
268272
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
@@ -271,6 +275,7 @@ async def _generate_agg(
271275
sampling_params=request.sampling_params,
272276
request_id=request.request_id,
273277
lora_request=lora_request,
278+
trace_headers=trace_headers,
274279
)
275280

276281
num_output_tokens_so_far = 0
@@ -302,6 +307,7 @@ async def _generate_disagg(
302307
request: vLLMMultimodalRequest,
303308
multi_modal_data: dict[str, Any],
304309
rng_ttft=None,
310+
context=None,
305311
):
306312
"""Prefill locally, then forward to a remote decode worker."""
307313
with _nvtx.annotate(
@@ -319,6 +325,7 @@ async def _generate_disagg(
319325
logger.debug("Prefill request: %s", prefill_only_request)
320326

321327
lora_request = self._resolve_lora_request(request.model)
328+
trace_headers = build_trace_headers(context) if context else None
322329
gen = self.engine_client.generate(
323330
prompt=TokensPrompt(
324331
prompt_token_ids=prefill_only_request.engine_prompt[
@@ -329,6 +336,7 @@ async def _generate_disagg(
329336
sampling_params=prefill_only_request.sampling_params,
330337
request_id=prefill_only_request.request_id,
331338
lora_request=lora_request,
339+
trace_headers=trace_headers,
332340
)
333341

334342
# Drain prefill generator (max_tokens=1, expect a single response)
@@ -382,7 +390,7 @@ async def _generate_disagg(
382390
async for (
383391
decode_response
384392
) in await self.decode_worker_client.round_robin( # type: ignore
385-
request.model_dump_json()
393+
request.model_dump_json(), context=context
386394
):
387395
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
388396
yield self._format_engine_output(output, num_output_tokens_so_far)
@@ -406,7 +414,7 @@ async def generate(self, raw_request: dict, context):
406414

407415
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
408416
multi_modal_data = await self._load_multimodal_data(
409-
image_urls, request.request_id
417+
image_urls, request.request_id, context
410418
)
411419
_nvtx.end_range(rng_load)
412420

@@ -415,13 +423,15 @@ async def generate(self, raw_request: dict, context):
415423
if self.enable_disagg and self.decode_worker_client:
416424
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
417425
async for chunk in self._generate_disagg(
418-
request, multi_modal_data, rng_ttft
426+
request, multi_modal_data, rng_ttft, context=context
419427
):
420428
yield chunk
421429
_nvtx.end_range(rng_disagg)
422430
else:
423431
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
424-
async for chunk in self._generate_agg(request, multi_modal_data, rng_ttft):
432+
async for chunk in self._generate_agg(
433+
request, multi_modal_data, rng_ttft, context=context
434+
):
425435
yield chunk
426436
_nvtx.end_range(rng_agg)
427437

components/src/dynamo/vllm/multimodal_handlers/worker_handler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import dynamo.nixl_connect as connect
99
from dynamo.common.utils import nvtx_utils as _nvtx
10+
from dynamo.common.utils.otel_tracing import build_trace_headers
1011
from dynamo.common.utils.time_section import time_and_log_code_section
1112
from dynamo.runtime import DistributedRuntime
1213

@@ -57,14 +58,14 @@ async def async_init(self, runtime: DistributedRuntime):
5758
async def generate(self, request: vLLMMultimodalRequest, context):
5859
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
5960
logger.debug(f"Got raw request: {request}")
61+
if not isinstance(request, vLLMMultimodalRequest):
62+
if isinstance(request, str):
63+
request = vLLMMultimodalRequest.model_validate_json(request)
64+
else:
65+
request = vLLMMultimodalRequest.model_validate(request)
6066
with time_and_log_code_section(
6167
f"[DECODE] request: {request.request_id} preprocessing time"
6268
):
63-
if not isinstance(request, vLLMMultimodalRequest):
64-
if isinstance(request, str):
65-
request = vLLMMultimodalRequest.model_validate_json(request)
66-
else:
67-
request = vLLMMultimodalRequest.model_validate(request)
6869
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
6970

7071
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
@@ -90,6 +91,7 @@ async def generate(self, request: vLLMMultimodalRequest, context):
9091
image_grid_thw, embeddings_shape, request.request_id
9192
)
9293
lora_request = self._resolve_lora_request(request.model)
94+
trace_headers = build_trace_headers(context) if context else None
9395

9496
with time_and_log_code_section(
9597
f"[DECODE] request: {request.request_id} generate time"
@@ -102,6 +104,7 @@ async def generate(self, request: vLLMMultimodalRequest, context):
102104
sampling_params=request.sampling_params,
103105
request_id=request.request_id,
104106
lora_request=lora_request,
107+
trace_headers=trace_headers,
105108
)
106109

107110
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")

components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async def _fetch_from_encode_workers(
140140
image_urls: List[str],
141141
request_id: str,
142142
receiver: AbstractEmbeddingReceiver,
143+
context=None,
143144
) -> tuple[List[MultiModalGroup], _PendingRelease | None]:
144145
"""Fan out image URLs to encode workers, load embeddings, and return ready groups.
145146
@@ -176,15 +177,15 @@ async def _fetch_from_encode_workers(
176177
encode_request.multimodal_inputs = batch
177178
payload = encode_request.model_dump_json()
178179
encode_response_streams.append(
179-
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
180+
await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
180181
)
181182
batch = []
182183

183184
if batch:
184185
encode_request.multimodal_inputs = batch
185186
payload = encode_request.model_dump_json()
186187
encode_response_streams.append(
187-
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
188+
await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
188189
)
189190

190191
with time_and_log_code_section(
@@ -223,6 +224,7 @@ async def _fetch_embeddings(
223224
request_id: str,
224225
receiver: AbstractEmbeddingReceiver,
225226
cache: MultimodalEmbeddingCacheManager | None = None,
227+
context=None,
226228
) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
227229
"""Fetch multimodal embeddings with transparent cache-through.
228230
@@ -262,6 +264,7 @@ async def _fetch_embeddings(
262264
miss_urls,
263265
request_id,
264266
receiver,
267+
context=context,
265268
)
266269

267270
# ── 3. Update cache (no-op when cache is None) ──────────────
@@ -293,6 +296,7 @@ async def load_multimodal_embeddings(
293296
model: str,
294297
embeddings_dtype: torch.dtype,
295298
cache: MultimodalEmbeddingCacheManager | None = None,
299+
context=None,
296300
) -> Dict[str, Any]:
297301
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
298302
@@ -307,6 +311,7 @@ async def load_multimodal_embeddings(
307311
request_id,
308312
receiver,
309313
cache=cache,
314+
context=context,
310315
)
311316

312317
multi_modal_data: Dict[str, Any] = defaultdict(list)

components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ async def fake_generate(**kwargs):
299299
decode_resp = MagicMock()
300300
decode_resp.data.return_value = decode_json
301301

302-
async def fake_round_robin(payload):
302+
async def fake_round_robin(payload, context=None):
303303
async def _stream():
304304
yield decode_resp
305305

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,35 @@ class Client:
209209
"""
210210
...
211211

212-
async def random(self, request: JsonLike) -> AsyncIterator[JsonLike]:
212+
async def random(
213+
self,
214+
request: JsonLike,
215+
annotated: bool | None = True,
216+
context: Context | None = None,
217+
) -> AsyncIterator[JsonLike]:
213218
"""
214219
Pick a random instance of the endpoint and issue the request
215220
"""
216221
...
217222

218-
async def round_robin(self, request: JsonLike) -> AsyncIterator[JsonLike]:
223+
async def round_robin(
224+
self,
225+
request: JsonLike,
226+
annotated: bool | None = True,
227+
context: Context | None = None,
228+
) -> AsyncIterator[JsonLike]:
219229
"""
220230
Pick the next instance of the endpoint in a round-robin fashion
221231
"""
222232
...
223233

224-
async def direct(self, request: JsonLike, instance: str) -> AsyncIterator[JsonLike]:
234+
async def direct(
235+
self,
236+
request: JsonLike,
237+
instance_id: int,
238+
annotated: bool | None = True,
239+
context: Context | None = None,
240+
) -> AsyncIterator[JsonLike]:
225241
"""
226242
Pick a specific instance of the endpoint
227243
"""

0 commit comments

Comments (0)