Skip to content

Commit b73c571

Browse files
feat: Add base64 and HTTP image URL support to vLLM workers (#4114)
Signed-off-by: Krishnan Prashanth <[email protected]>
1 parent 1a9aeab commit b73c571

File tree

5 files changed

+217
-41
lines changed

5 files changed

+217
-41
lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
from abc import ABC, abstractmethod
88
from contextlib import asynccontextmanager
9-
from typing import Any, AsyncGenerator, Dict
9+
from typing import Any, AsyncGenerator, Dict, Final
1010

1111
from vllm.inputs import TokensPrompt
1212
from vllm.sampling_params import SamplingParams
@@ -16,6 +16,13 @@
1616
from dynamo.runtime.logging import configure_dynamo_logging
1717

1818
from .engine_monitor import VllmEngineMonitor
19+
from .multimodal_utils.image_loader import ImageLoader
20+
21+
# Multimodal data dictionary keys
22+
IMAGE_URL_KEY: Final = "image_url"
23+
VIDEO_URL_KEY: Final = "video_url"
24+
URL_VARIANT_KEY: Final = "Url"
25+
DECODED_VARIANT_KEY: Final = "Decoded"
1926

2027
configure_dynamo_logging()
2128
logger = logging.getLogger(__name__)
@@ -65,6 +72,7 @@ def __init__(self, runtime, component, engine, default_sampling_params):
6572
self.default_sampling_params = default_sampling_params
6673
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
6774
self.engine_monitor = VllmEngineMonitor(runtime, engine)
75+
self.image_loader = ImageLoader()
6876

6977
@abstractmethod
7078
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
@@ -111,6 +119,50 @@ def cleanup(self):
111119
"""Override in subclasses if cleanup is needed."""
112120
pass
113121

122+
async def _extract_multimodal_data(
123+
self, request: Dict[str, Any]
124+
) -> Dict[str, Any] | None:
125+
"""
126+
Extract and decode multimodal data from PreprocessedRequest.
127+
"""
128+
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
129+
return None
130+
131+
mm_map = request["multi_modal_data"]
132+
vllm_mm_data = {}
133+
134+
# Process image_url entries
135+
images = []
136+
for item in mm_map.get(IMAGE_URL_KEY, []):
137+
if isinstance(item, dict) and URL_VARIANT_KEY in item:
138+
url = item[URL_VARIANT_KEY]
139+
try:
140+
# ImageLoader supports both data: and http(s): URLs with caching
141+
image = await self.image_loader.load_image(url)
142+
images.append(image)
143+
logger.debug(f"Loaded image from URL: {url[:80]}...")
144+
except Exception:
145+
logger.exception(f"Failed to load image from {url[:80]}...")
146+
raise
147+
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
148+
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
149+
# Will contain NIXL metadata for direct memory access
150+
# TODO: Implement NIXL read when PRs merge
151+
logger.warning(
152+
"Decoded multimodal data not yet supported in standard worker"
153+
)
154+
155+
if images:
156+
# vLLM expects single image or list
157+
vllm_mm_data["image"] = images[0] if len(images) == 1 else images
158+
logger.debug(f"Extracted {len(images)} image(s) for multimodal processing")
159+
160+
# Handle video_url entries (future expansion)
161+
if VIDEO_URL_KEY in mm_map:
162+
logger.warning("Video multimodal data not yet supported in standard worker")
163+
164+
return vllm_mm_data if vllm_mm_data else None
165+
114166
async def generate_tokens(
115167
self, prompt, sampling_params, request_id, data_parallel_rank=None
116168
):
@@ -168,7 +220,12 @@ async def generate(self, request, context):
168220
request_id = context.id()
169221
logger.debug(f"Decode Request ID: {request_id}")
170222

171-
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
223+
# Extract and decode multimodal data if present
224+
multi_modal_data = await self._extract_multimodal_data(request)
225+
226+
prompt = TokensPrompt(
227+
prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data
228+
)
172229

173230
# Build sampling params from request
174231
sampling_params = build_sampling_params(request, self.default_sampling_params)
@@ -210,8 +267,13 @@ async def generate(self, request, context):
210267
request_id = context.id()
211268
logger.debug(f"Prefill Request ID: {request_id}")
212269

270+
# Extract and decode multimodal data if present
271+
multi_modal_data = await self._extract_multimodal_data(request)
272+
213273
token_ids = request["token_ids"]
214-
prompt = TokensPrompt(prompt_token_ids=token_ids)
274+
prompt = TokensPrompt(
275+
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
276+
)
215277

216278
# Build sampling params from request using shared utility
217279
sampling_params = build_sampling_params(request, self.default_sampling_params)
Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
#!/bin/bash
22
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Aggregated multimodal serving with standard Dynamo preprocessing
6+
#
7+
# Architecture: Single-worker PD (Prefill-Decode)
8+
# - Frontend: Rust OpenAIPreprocessor handles image URLs (HTTP and data:// base64)
9+
# - Worker: Standard vLLM worker with vision model support
10+
#
11+
# For EPD (Encode-Prefill-Decode) architecture with dedicated encoding worker,
12+
# see agg_multimodal_epd.sh
13+
414
set -e
515
trap 'echo Cleaning up...; kill 0' EXIT
616

717
# Default values
8-
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
9-
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
10-
PROVIDED_PROMPT_TEMPLATE=""
18+
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
1119

1220
# Parse command line arguments
1321
while [[ $# -gt 0 ]]; do
@@ -16,15 +24,10 @@ while [[ $# -gt 0 ]]; do
1624
MODEL_NAME=$2
1725
shift 2
1826
;;
19-
--prompt-template)
20-
PROVIDED_PROMPT_TEMPLATE=$2
21-
shift 2
22-
;;
2327
-h|--help)
2428
echo "Usage: $0 [OPTIONS]"
2529
echo "Options:"
26-
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
27-
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
30+
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
2831
echo " -h, --help Show this help message"
2932
exit 0
3033
;;
@@ -36,37 +39,23 @@ while [[ $# -gt 0 ]]; do
3639
esac
3740
done
3841

39-
# Set PROMPT_TEMPLATE based on the MODEL_NAME
40-
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
41-
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
42-
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
43-
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
44-
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
45-
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
46-
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
47-
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
48-
else
49-
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
50-
echo "Please provide a prompt template using --prompt-template option."
51-
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
52-
exit 1
53-
fi
54-
55-
# run ingress
42+
# Start frontend with Rust OpenAIPreprocessor
5643
python -m dynamo.frontend --http-port=8000 &
5744

58-
# To make Qwen2.5-VL fit in A100 40GB, set the following extra arguments
45+
# Configure GPU memory optimization for specific models
5946
EXTRA_ARGS=""
6047
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
6148
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
6249
fi
6350

64-
# run processor
65-
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
66-
67-
# run E/P/D workers
68-
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME &
69-
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --model $MODEL_NAME $EXTRA_ARGS &
51+
# Start vLLM worker with vision model
52+
# Multimodal data (images) are decoded in the backend worker using ImageLoader
53+
# --enforce-eager: Quick deployment (remove for production)
54+
# --connector none: No KV transfer needed for aggregated serving
55+
DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \
56+
python -m dynamo.vllm --model $MODEL_NAME --enforce-eager --connector none $EXTRA_ARGS
7057

7158
# Wait for all background processes to complete
7259
wait
60+
61+
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/bin/bash
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# EPD (Encode-Prefill-Decode) multimodal deployment
6+
#
7+
# Architecture: 3-component disaggregation
8+
# - Processor: Python-based preprocessor (bypasses Rust OpenAIPreprocessor)
9+
# - Encode Worker: Dedicated vision encoder that extracts image embeddings
10+
# - PD Worker: Standard prefill/decode worker that receives embeddings via NIXL
11+
#
12+
# Benefits: Decouples encoding from inference, enables independent scaling
13+
# For standard single-worker deployment, see agg_multimodal.sh
14+
15+
set -e
16+
trap 'echo Cleaning up...; kill 0' EXIT
17+
18+
# Default values
19+
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
20+
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
21+
PROVIDED_PROMPT_TEMPLATE=""
22+
23+
# Parse command line arguments
24+
while [[ $# -gt 0 ]]; do
25+
case $1 in
26+
--model)
27+
MODEL_NAME=$2
28+
shift 2
29+
;;
30+
--prompt-template)
31+
PROVIDED_PROMPT_TEMPLATE=$2
32+
shift 2
33+
;;
34+
-h|--help)
35+
echo "Usage: $0 [OPTIONS]"
36+
echo "Options:"
37+
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
38+
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
39+
echo " -h, --help Show this help message"
40+
exit 0
41+
;;
42+
*)
43+
echo "Unknown option: $1"
44+
echo "Use --help for usage information"
45+
exit 1
46+
;;
47+
esac
48+
done
49+
50+
# Set PROMPT_TEMPLATE based on the MODEL_NAME
51+
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
52+
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
53+
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
54+
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
55+
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
56+
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
57+
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
58+
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
59+
else
60+
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
61+
echo "Please provide a prompt template using --prompt-template option."
62+
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
63+
exit 1
64+
fi
65+
66+
# Start frontend (HTTP endpoint)
67+
python -m dynamo.frontend --http-port=8000 &
68+
69+
# To make Qwen2.5-VL fit in A100 40GB, set the following extra arguments
70+
EXTRA_ARGS=""
71+
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
72+
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
73+
fi
74+
75+
# Start processor (Python-based preprocessing, handles prompt templating)
76+
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
77+
78+
# run E/P/D workers
79+
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME &
80+
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --model $MODEL_NAME $EXTRA_ARGS &
81+
82+
# Wait for all background processes to complete
83+
wait

lib/bindings/python/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/serve/test_vllm.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,10 @@ class VLLMConfig(EngineConfig):
104104
completion_payload_default(expected_response=["joke"]),
105105
],
106106
),
107-
"multimodal_agg_llava": VLLMConfig(
108-
name="multimodal_agg_llava",
107+
"multimodal_agg_llava_epd": VLLMConfig(
108+
name="multimodal_agg_llava_epd",
109109
directory=vllm_dir,
110-
script_name="agg_multimodal.sh",
110+
script_name="agg_multimodal_epd.sh",
111111
marks=[pytest.mark.gpu_2],
112112
model="llava-hf/llava-1.5-7b-hf",
113113
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
@@ -128,16 +128,42 @@ class VLLMConfig(EngineConfig):
128128
)
129129
],
130130
),
131+
"multimodal_agg_qwen_epd": VLLMConfig(
132+
name="multimodal_agg_qwen_epd",
133+
directory=vllm_dir,
134+
script_name="agg_multimodal_epd.sh",
135+
marks=[pytest.mark.gpu_2],
136+
model="Qwen/Qwen2.5-VL-7B-Instruct",
137+
delayed_start=0,
138+
script_args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
139+
timeout=360,
140+
request_payloads=[
141+
chat_payload(
142+
[
143+
{"type": "text", "text": "What is in this image?"},
144+
{
145+
"type": "image_url",
146+
"image_url": {
147+
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
148+
},
149+
},
150+
],
151+
repeat_count=1,
152+
expected_response=["bus"],
153+
)
154+
],
155+
),
131156
"multimodal_agg_qwen": VLLMConfig(
132157
name="multimodal_agg_qwen",
133158
directory=vllm_dir,
134159
script_name="agg_multimodal.sh",
135160
marks=[pytest.mark.gpu_2],
136161
model="Qwen/Qwen2.5-VL-7B-Instruct",
137-
delayed_start=0,
138162
script_args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
163+
delayed_start=0,
139164
timeout=360,
140165
request_payloads=[
166+
# HTTP URL test
141167
chat_payload(
142168
[
143169
{"type": "text", "text": "What is in this image?"},
@@ -150,7 +176,21 @@ class VLLMConfig(EngineConfig):
150176
],
151177
repeat_count=1,
152178
expected_response=["bus"],
153-
)
179+
),
180+
# Base64 data URL test (1x1 PNG inline, avoids network fetch)
181+
chat_payload(
182+
[
183+
{"type": "text", "text": "What do you see in this image?"},
184+
{
185+
"type": "image_url",
186+
"image_url": {
187+
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA6fptVAAAACklEQVR4nGNoAAAAggCBd81ytgAAAABJRU5ErkJggg=="
188+
},
189+
},
190+
],
191+
repeat_count=1,
192+
expected_response=[], # Just validate no error
193+
),
154194
],
155195
),
156196
# TODO: Update this test case when we have video multimodal support in vllm official components

0 commit comments

Comments
 (0)