Skip to content

Commit 52090e2

Browse files
authored
feat: refactor vllm multimodal example (#3634)
Signed-off-by: ayushag <[email protected]>
1 parent 7d78fda commit 52090e2

File tree

16 files changed

+1855
-12
lines changed

16 files changed

+1855
-12
lines changed

examples/multimodal/launch/agg.sh renamed to components/backends/vllm/launch/agg_multimodal.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,18 @@ fi
5555
# run ingress
5656
python -m dynamo.frontend --http-port=8000 &
5757

58-
# run processor
59-
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
60-
6158
# To make Qwen2.5-VL fit in A100 40GB, set the following extra arguments
6259
EXTRA_ARGS=""
6360
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
6461
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
6562
fi
6663

64+
# run processor
65+
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
66+
6767
# run E/P/D workers
68-
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
69-
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill $EXTRA_ARGS &
68+
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME &
69+
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --model $MODEL_NAME $EXTRA_ARGS &
7070

7171
# Wait for all background processes to complete
7272
wait

components/src/dynamo/vllm/args.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ class Config:
6565
tool_call_parser: Optional[str] = None
6666
reasoning_parser: Optional[str] = None
6767

68+
# multimodal options
69+
multimodal_processor: bool = False
70+
multimodal_encode_worker: bool = False
71+
multimodal_worker: bool = False
72+
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
6873
# dump config to file
6974
dump_config_to: Optional[str] = None
7075

@@ -137,6 +142,34 @@ def parse_args() -> Config:
137142
default=None,
138143
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
139144
)
145+
parser.add_argument(
146+
"--multimodal-processor",
147+
action="store_true",
148+
help="Run as multimodal processor component for handling multimodal requests",
149+
)
150+
parser.add_argument(
151+
"--multimodal-encode-worker",
152+
action="store_true",
153+
help="Run as multimodal encode worker component for processing images/videos",
154+
)
155+
parser.add_argument(
156+
"--multimodal-worker",
157+
action="store_true",
158+
help="Run as multimodal worker component for LLM inference with multimodal data",
159+
)
160+
parser.add_argument(
161+
"--mm-prompt-template",
162+
type=str,
163+
default="USER: <image>\n<prompt> ASSISTANT:",
164+
help=(
165+
"Different multi-modal models expect the prompt to contain different special media prompts. "
166+
"The processor will use this argument to construct the final prompt. "
167+
"User prompt will replace '<prompt>' in the provided template. "
168+
"For example, if the user prompt is 'please describe the image' and the prompt template is "
169+
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
170+
"'USER: <image> please describe the image ASSISTANT:'."
171+
),
172+
)
140173
add_config_dump_args(parser)
141174

142175
parser = AsyncEngineArgs.add_cli_args(parser)
@@ -161,8 +194,35 @@ def parse_args() -> Config:
161194
config.served_model_name = None
162195

163196
config.namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
164-
config.component = "prefill" if args.is_prefill_worker else "backend"
165-
config.endpoint = "generate"
197+
198+
# Check multimodal role exclusivity
199+
mm_flags = (
200+
int(bool(args.multimodal_processor))
201+
+ int(bool(args.multimodal_encode_worker))
202+
+ int(bool(args.multimodal_worker))
203+
)
204+
if mm_flags > 1:
205+
raise ValueError(
206+
"Use only one of --multimodal-processor, --multimodal-encode-worker, or --multimodal-worker"
207+
)
208+
209+
# Set component and endpoint based on worker type
210+
if args.multimodal_processor:
211+
config.component = "processor"
212+
config.endpoint = "generate"
213+
elif args.multimodal_encode_worker:
214+
config.component = "encoder"
215+
config.endpoint = "generate"
216+
elif args.multimodal_worker and args.is_prefill_worker:
217+
config.component = "prefill"
218+
config.endpoint = "generate"
219+
elif args.is_prefill_worker:
220+
config.component = "prefill"
221+
config.endpoint = "generate"
222+
else:
223+
config.component = "backend"
224+
config.endpoint = "generate"
225+
166226
config.engine_args = engine_args
167227
config.is_prefill_worker = args.is_prefill_worker
168228
config.is_decode_worker = args.is_decode_worker
@@ -173,6 +233,10 @@ def parse_args() -> Config:
173233
config.tool_call_parser = args.dyn_tool_call_parser
174234
config.reasoning_parser = args.dyn_reasoning_parser
175235
config.custom_jinja_template = args.custom_jinja_template
236+
config.multimodal_processor = args.multimodal_processor
237+
config.multimodal_encode_worker = args.multimodal_encode_worker
238+
config.multimodal_worker = args.multimodal_worker
239+
config.mm_prompt_template = args.mm_prompt_template
176240

177241
# Validate custom Jinja template file exists if provided
178242
if config.custom_jinja_template is not None:

components/src/dynamo/vllm/main.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
)
2727
from dynamo.runtime import DistributedRuntime, dynamo_worker
2828
from dynamo.runtime.logging import configure_dynamo_logging
29+
from dynamo.vllm.multimodal_handlers import (
30+
EncodeWorkerHandler,
31+
MultimodalPDWorkerHandler,
32+
ProcessorHandler,
33+
)
2934

3035
from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
3136
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
@@ -92,7 +97,17 @@ def signal_handler():
9297
if not os.path.exists(config.model):
9398
config.model = config.engine_args.model = await fetch_llm(config.model)
9499

95-
if config.is_prefill_worker:
100+
# Route to appropriate initialization based on config flags
101+
if config.multimodal_processor:
102+
await init_multimodal_processor(runtime, config)
103+
logger.debug("init_multimodal_processor completed")
104+
elif config.multimodal_encode_worker:
105+
await init_multimodal_encode_worker(runtime, config)
106+
logger.debug("init_multimodal_encode_worker completed")
107+
elif config.multimodal_worker:
108+
await init_multimodal_worker(runtime, config)
109+
logger.debug("init_multimodal_worker completed")
110+
elif config.is_prefill_worker:
96111
await init_prefill(runtime, config)
97112
logger.debug("init_prefill completed")
98113
else:
@@ -430,6 +445,147 @@ def get_engine_cache_info(engine: AsyncLLM):
430445
raise
431446

432447

448+
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
449+
"""Initialize multimodal processor component"""
450+
component = runtime.namespace(config.namespace).component(config.component)
451+
await component.create_service()
452+
453+
generate_endpoint = component.endpoint(config.endpoint)
454+
455+
# Get encode worker client
456+
encode_worker_client = (
457+
await runtime.namespace(config.namespace)
458+
.component("encoder")
459+
.endpoint("generate")
460+
.client()
461+
)
462+
463+
# Get prompt template from args (must be passed via environment or command line)
464+
mm_prompt_template = config.mm_prompt_template
465+
466+
handler = ProcessorHandler(
467+
config.engine_args,
468+
encode_worker_client,
469+
mm_prompt_template,
470+
)
471+
472+
logger.info("Waiting for Encoder Worker Instances ...")
473+
await encode_worker_client.wait_for_instances()
474+
475+
# Register the endpoint as entrypoint to a model
476+
await register_llm(
477+
ModelInput.Text, # Custom processor is used and this type bypasses SDK processor
478+
ModelType.Chat,
479+
generate_endpoint,
480+
config.model,
481+
config.served_model_name,
482+
kv_cache_block_size=config.engine_args.block_size,
483+
)
484+
485+
logger.info("Starting to serve the processor endpoint...")
486+
487+
try:
488+
await asyncio.gather(
489+
generate_endpoint.serve_endpoint(
490+
handler.generate, metrics_labels=[("model", config.model)]
491+
),
492+
)
493+
except Exception as e:
494+
logger.error(f"Failed to serve endpoints: {e}")
495+
raise
496+
finally:
497+
handler.cleanup()
498+
499+
500+
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
501+
"""Initialize multimodal encode worker component"""
502+
component = runtime.namespace(config.namespace).component(config.component)
503+
await component.create_service()
504+
505+
generate_endpoint = component.endpoint(config.endpoint)
506+
507+
# Get PD worker client
508+
# In multimodal mode, the PD worker always registers as "backend"
509+
# (even in disaggregated mode with prefill/decode split, we still connect to "backend")
510+
pd_worker_client = (
511+
await runtime.namespace(config.namespace)
512+
.component("backend")
513+
.endpoint("generate")
514+
.client()
515+
)
516+
517+
handler = EncodeWorkerHandler(
518+
config.engine_args,
519+
pd_worker_client,
520+
)
521+
await handler.async_init(runtime)
522+
logger.info("Waiting for PD Worker Instances ...")
523+
await pd_worker_client.wait_for_instances()
524+
logger.info("Starting to serve the encode worker endpoint...")
525+
526+
try:
527+
await asyncio.gather(
528+
generate_endpoint.serve_endpoint(
529+
handler.generate, metrics_labels=[("model", config.model)]
530+
),
531+
)
532+
except Exception as e:
533+
logger.error(f"Failed to serve endpoints: {e}")
534+
raise
535+
finally:
536+
handler.cleanup()
537+
538+
539+
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
540+
"""Initialize multimodal worker component for aggregated or disaggregated mode"""
541+
542+
component = runtime.namespace(config.namespace).component(config.component)
543+
await component.create_service()
544+
545+
generate_endpoint = component.endpoint(config.endpoint)
546+
clear_endpoint = component.endpoint("clear_kv_blocks")
547+
548+
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
549+
550+
# TODO: Support Disaggregated mode separately
551+
client = (
552+
await runtime.namespace(config.namespace)
553+
.component("backend")
554+
.endpoint("generate")
555+
.client()
556+
)
557+
558+
handler = MultimodalPDWorkerHandler(
559+
runtime, component, engine_client, config, client
560+
)
561+
562+
await handler.async_init(runtime)
563+
564+
# Set up KV event publisher for prefix caching if enabled
565+
kv_publisher = setup_kv_event_publisher(
566+
config, component, generate_endpoint, vllm_config
567+
)
568+
if kv_publisher:
569+
handler.kv_publisher = kv_publisher
570+
571+
metrics_labels = [("model", config.model)]
572+
573+
try:
574+
await asyncio.gather(
575+
generate_endpoint.serve_endpoint(
576+
handler.generate, metrics_labels=metrics_labels
577+
),
578+
clear_endpoint.serve_endpoint(
579+
handler.clear_kv_blocks, metrics_labels=metrics_labels
580+
),
581+
)
582+
except Exception as e:
583+
logger.error(f"Failed to serve endpoints: {e}")
584+
raise
585+
finally:
586+
handler.cleanup()
587+
588+
433589
def main():
434590
uvloop.run(worker())
435591

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dynamo.vllm.multimodal_handlers.encode_worker_handler import EncodeWorkerHandler
5+
from dynamo.vllm.multimodal_handlers.processor_handler import ProcessorHandler
6+
from dynamo.vllm.multimodal_handlers.worker_handler import (
7+
MultimodalDecodeWorkerHandler,
8+
MultimodalPDWorkerHandler,
9+
)
10+
11+
__all__ = [
12+
"EncodeWorkerHandler",
13+
"ProcessorHandler",
14+
"MultimodalPDWorkerHandler",
15+
"MultimodalDecodeWorkerHandler",
16+
]

0 commit comments

Comments
 (0)