Skip to content

Commit 75a69cd

Browse files
feat: add vLLM V1 PD disagg example (ai-dynamo#1013)
1 parent 4fd4d53 commit 75a69cd

File tree

10 files changed

+777
-0
lines changed

10 files changed

+777
-0
lines changed

examples/vllm_v1/README.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
<!--
2+
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
SPDX-License-Identifier: Apache-2.0
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
-->
17+
18+
# vLLM Deployment Examples
19+
20+
This directory contains examples for deploying vLLM models in both aggregated and disaggregated configurations.
21+
22+
## Prerequisites
23+
24+
1. Install vLLM:
25+
```bash
26+
# Note: Currently requires installation from main branch
27+
# From vLLM 0.8.6 onwards, you can install directly from wheel
28+
git clone https://github.com/vllm-project/vllm.git
29+
VLLM_USE_PRECOMPILED=1 uv pip install --editable ./vllm/
30+
```
31+
32+
2. Start required services:
33+
```bash
34+
docker compose -f deploy/metrics/docker-compose.yml up -d
35+
```
36+
37+
## Running the Server
38+
39+
### Aggregated Deployment
40+
```bash
41+
cd examples/vllm_v1
42+
dynamo serve graphs.agg:Frontend -f configs/agg.yaml
43+
```
44+
45+
### Disaggregated Deployment
46+
```bash
47+
cd examples/vllm_v1
48+
dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
49+
```
50+
51+
## Testing the API
52+
53+
Send a test request using curl:
54+
```bash
55+
curl localhost:8000/v1/completions \
56+
-H "Content-Type: application/json" \
57+
-d '{
58+
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
59+
"prompt": "In the heart of Eldoria...",
60+
"stream": false,
61+
"max_tokens": 30
62+
}'
63+
```
64+
65+
For more detailed explanations, refer to the main [LLM examples README](../llm/README.md).
66+
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import subprocess
18+
from pathlib import Path
19+
20+
from components.simple_load_balancer import SimpleLoadBalancer
21+
from fastapi import FastAPI
22+
from pydantic import BaseModel
23+
24+
import dynamo.sdk as sdk
25+
from dynamo.sdk import depends, service
26+
from dynamo.sdk.lib.config import ServiceConfig
27+
from dynamo.sdk.lib.image import DYNAMO_IMAGE
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
def get_dynamo_run_binary():
    """Return the command used to launch ``dynamo-run``.

    Looks for ``cli/bin/dynamo-run`` inside the directory that contains the
    imported ``dynamo.sdk`` module. If that file exists, its path is returned
    as a string; otherwise the bare command name ``"dynamo-run"`` is returned.
    """
    bundled_binary = Path(sdk.__file__).parent / "cli/bin/dynamo-run"
    return str(bundled_binary) if bundled_binary.exists() else "dynamo-run"
40+
41+
42+
class FrontendConfig(BaseModel):
    """Settings for the Frontend service: served model and HTTP server options."""

    # Name of the served model (not read by the code visible in this file).
    served_model_name: str
    # Endpoint path; Frontend prefixes it with "dyn://" and passes it to
    # dynamo-run as the "out=" target.
    endpoint: str
    # Port handed to dynamo-run via "--http-port".
    port: int = 8080
48+
49+
50+
# TODO: move these to common for all LLMs once we adopt dynamo-run
51+
@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    workers=1,
    image=DYNAMO_IMAGE,
    app=FastAPI(title="LLM Example"),
)
class Frontend:
    """Service that launches ``dynamo-run`` as an HTTP ingress child process.

    On construction the ``Frontend`` section of the service configuration is
    validated into a ``FrontendConfig`` and ``dynamo-run`` is spawned with
    ``in=http`` and ``out=dyn://<endpoint>`` on the configured port.
    """

    worker = depends(SimpleLoadBalancer)

    def __init__(self):
        """Initialize Frontend service with HTTP server and model configuration."""
        config = ServiceConfig.get_instance()
        self.frontend_config = FrontendConfig(**config.get("Frontend", {}))
        # Handle of the dynamo-run child process; assigned by
        # start_ingress_and_processor().
        # NOTE(review): nothing in this class terminates the child process;
        # confirm that it is cleaned up on service shutdown.
        self.process = None

        self.start_ingress_and_processor()

    def start_ingress_and_processor(self):
        """Start the dynamo-run based ingress and processor.

        Spawns ``dynamo-run in=http out=dyn://<endpoint> --http-port <port>``
        and stores the resulting ``subprocess.Popen`` handle in
        ``self.process``. The call does not wait for the child to exit.
        """
        dynamo_run_binary = get_dynamo_run_binary()
        endpoint = f"dyn://{self.frontend_config.endpoint}"

        # Logged once; the original emitted this identical message twice.
        logger.info(
            f"Starting HTTP server and processor on port {self.frontend_config.port}"
        )
        logger.info(f"Endpoint: {endpoint}")

        self.process = subprocess.Popen(
            [
                dynamo_run_binary,
                "in=http",
                f"out={endpoint}",
                "--http-port",
                str(self.frontend_config.port),
            ],
            # None means the child inherits this process's stdout/stderr.
            stdout=None,
            stderr=None,
        )
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import copy
17+
import logging
18+
import uuid
19+
from typing import AsyncGenerator, Optional
20+
21+
from components.worker import VllmDecodeWorker, VllmPrefillWorker
22+
from utils.args import parse_vllm_args
23+
from utils.protocol import MyRequestOutput, PreprocessedRequest, vLLMGenerateRequest
24+
from vllm.inputs import TokensPrompt
25+
from vllm.sampling_params import SamplingParams
26+
27+
from dynamo.llm import ModelType, register_llm
28+
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
29+
30+
logger = logging.getLogger(__name__)
31+
32+
33+
@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class SimpleLoadBalancer:
    """Service that fans preprocessed requests out to the vLLM workers.

    The service registers its own ``generate`` endpoint as an LLM backend and,
    for each request, optionally calls a prefill worker first (when
    ``enable_disagg`` is set) before streaming the decode worker's output back
    to the caller. Workers are picked with the clients' ``round_robin`` call.
    """

    prefill_worker = depends(VllmPrefillWorker)
    decode_worker = depends(VllmDecodeWorker)

    def __init__(self):
        """Parse the vLLM engine args and derive default sampling parameters."""
        class_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(class_name, "")
        model_config = self.engine_args.create_model_config()
        # Used as keyword arguments for every SamplingParams built in
        # _create_vllm_request().
        self.default_sampling_params = model_config.get_diff_sampling_param()
        self.enable_disagg = self.engine_args.enable_disagg

    @async_on_start
    async def async_init(self):
        """Register the ``generate`` endpoint and create the worker clients.

        Registers this component's ``generate`` endpoint once per served model
        name, then builds clients for the ``generate`` endpoints of the decode
        and prefill workers.
        """
        runtime = dynamo_context["runtime"]
        logger.info("Registering LLM for discovery")
        comp_ns, comp_name = SimpleLoadBalancer.dynamo_address()  # type: ignore
        endpoint_name = "generate"
        for served_model_name in self.engine_args.served_model_name:
            logger.info(
                f"Registering endpoint {endpoint_name} with model {self.engine_args.model} and served_model_name {served_model_name}"
            )
            endpoint = (
                runtime.namespace(comp_ns).component(comp_name).endpoint(endpoint_name)
            )
            await register_llm(
                ModelType.Backend,
                endpoint,
                self.engine_args.model,
                served_model_name,
            )

        comp_ns, comp_name = VllmDecodeWorker.dynamo_address()  # type: ignore
        self.decode_worker_client = (
            await runtime.namespace(comp_ns)
            .component(comp_name)
            .endpoint("generate")
            .client()
        )

        comp_ns, comp_name = VllmPrefillWorker.dynamo_address()  # type: ignore
        self.prefill_worker_client = (
            await runtime.namespace(comp_ns)
            .component(comp_name)
            .endpoint("generate")
            .client()
        )

        logger.info("SimpleLoadBalancer has been initialized")

    async def send_request_to_prefill(
        self, request: vLLMGenerateRequest
    ) -> MyRequestOutput:
        """Send a one-token copy of ``request`` to a prefill worker.

        The request is deep-copied so the caller's object is left untouched.
        The copy is tagged with ``kv_transfer_params={"do_remote_decode": True}``
        and limited to exactly one generated token.

        Returns:
            The first response produced by the prefill worker.

        Raises:
            RuntimeError: If the prefill worker's stream ends without yielding
                any response.
        """
        logger.debug("Sending request to prefill")

        prefill_request = copy.deepcopy(request)
        extra_args = prefill_request.sampling_params.extra_args or {}
        extra_args["kv_transfer_params"] = {
            "do_remote_decode": True,
        }
        prefill_request.sampling_params.extra_args = extra_args
        # Presumably prefill only has to process the prompt; the decode worker
        # generates the real output -- TODO confirm.
        prefill_request.sampling_params.max_tokens = 1
        prefill_request.sampling_params.min_tokens = 1

        logger.debug("Prefill request: %s", prefill_request.model_dump_json())

        # Only the first streamed response is consumed.
        async for prefill_response in await self.prefill_worker_client.round_robin(
            prefill_request.model_dump_json()
        ):
            return MyRequestOutput.model_validate_json(prefill_response.data())

        # Falling through used to return None implicitly, which only surfaced
        # later as an AttributeError in generate(); fail here with a clear
        # message instead.
        raise RuntimeError("Prefill worker returned no response")

    async def send_request_to_decode(
        self,
        request: vLLMGenerateRequest,
        prefill_response: Optional[MyRequestOutput] = None,
    ) -> AsyncGenerator[MyRequestOutput, None]:
        """Stream responses from a decode worker for ``request``.

        Args:
            request: The request to decode; it is deep-copied before being
                modified.
            prefill_response: Optional prefill result. When given, its
                ``kv_transfer_params`` are attached to the decode request's
                ``extra_args``.

        Yields:
            Each decode worker response parsed into ``MyRequestOutput``.
        """
        logger.debug("Sending request to decode")

        decode_request = copy.deepcopy(request)

        if prefill_response:
            extra_args = decode_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
            decode_request.sampling_params.extra_args = extra_args

        logger.debug("Decode request: %s", decode_request.model_dump_json())

        async for decode_response in await self.decode_worker_client.round_robin(
            decode_request.model_dump_json()
        ):
            yield MyRequestOutput.model_validate_json(decode_response.data())

    @dynamo_endpoint()
    async def generate(self, request: PreprocessedRequest):
        """Handle one preprocessed request and stream token deltas back.

        Builds a vLLM request, runs the prefill step first when
        ``enable_disagg`` is set, then forwards the decode worker's stream
        through ``_stream_response``.
        """
        logger.debug(
            "Processor received completion request: %s", request.model_dump_json()
        )

        vllm_request = self._create_vllm_request(request)

        logger.debug("VLLM request: %s", vllm_request.model_dump_json())

        if self.enable_disagg:
            prefill_response = await self.send_request_to_prefill(vllm_request)

            logger.debug("Prefill response: %s", prefill_response.model_dump_json())
        else:
            prefill_response = None

        gen = self.send_request_to_decode(vllm_request, prefill_response)
        async for res in self._stream_response(gen):
            yield res

    def _create_vllm_request(self, request: PreprocessedRequest) -> vLLMGenerateRequest:
        """Convert a ``PreprocessedRequest`` into a ``vLLMGenerateRequest``.

        Sampling parameters start from ``self.default_sampling_params`` and are
        overridden by every option in ``request.sampling_options`` that is set
        (not ``None``) and names an attribute of ``SamplingParams``. A truthy
        ``request.stop_conditions.max_tokens`` overrides ``max_tokens``. A
        fresh random hex request id is assigned.
        """
        request_id = uuid.uuid4().hex

        prompt = TokensPrompt(prompt_token_ids=request.token_ids)

        sampling_params = SamplingParams(**self.default_sampling_params)
        for key, value in request.sampling_options.model_dump().items():
            # Skip only unset options. A plain truthiness test would also drop
            # legitimate falsy values such as temperature=0.0.
            if value is None:
                continue
            if hasattr(sampling_params, key):
                setattr(sampling_params, key, value)

        max_tokens = request.stop_conditions.max_tokens
        if max_tokens:
            sampling_params.max_tokens = max_tokens

        return vLLMGenerateRequest(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

    async def _stream_response(self, gen: AsyncGenerator[MyRequestOutput, None]):
        """Turn decode responses into incremental token dictionaries.

        For each response, yields a dict whose ``token_ids`` are the tokens of
        ``outputs[0]`` beyond those already forwarded, plus ``finish_reason``
        and ``stop_reason`` when they are set. The stream ends with
        ``finish_reason`` ``"stop"`` on a finished response, or ``"error"``
        when a response carries no outputs.
        """
        num_output_tokens_so_far = 0
        async for res in gen:
            # Serialising every response is wasted work unless debug logging
            # is actually enabled, so check the level first.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Decode response: %s", res.model_dump_json())
            # res is our MyRequestOutput

            # This is the expected way for a request to end.
            # The new token ID will be eos, don't forward it.
            # NOTE(review): a finished response is always reported as "stop"
            # and its tokens are dropped, whatever its own finish_reason is --
            # confirm this is intended for length-limited generations.
            if res.finished:
                yield {"finish_reason": "stop", "token_ids": []}
                break

            if not res.outputs:
                yield {"finish_reason": "error", "token_ids": []}
                break

            output = res.outputs[0]
            next_total_toks = len(output.token_ids)
            # Forward only the tokens not sent yet; this assumes
            # output.token_ids is cumulative across responses -- TODO confirm.
            out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
            if output.finish_reason:
                out["finish_reason"] = output.finish_reason
            if output.stop_reason:
                out["stop_reason"] = output.stop_reason
            yield out
            num_output_tokens_so_far = next_total_toks

0 commit comments

Comments
 (0)