
Commit 88caaf1

[Feature] Split vllm-project#1303 Part 1: PD disaggregation scaffolding (vllm-project#1863)
Signed-off-by: Jinheng Li <ahengljh@gmail.com>
1 parent 6a5fa58 commit 88caaf1

File tree

5 files changed: +723 -0 lines changed


docs/configuration/README.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -14,6 +14,10 @@ For introduction, please check [Introduction for stage config](./stage_configs.m
 
 - **[GPU Memory Calculation and Configuration](./gpu_memory_utilization.md)** - Guide on how to calculate memory requirements and set up `gpu_memory_utilization` for optimal performance
 
+## Multi-Stage Recipes
+
+- **[Prefill-Decode Disaggregation](./pd_disaggregation.md)** - How to derive a PD-aware Qwen3-Omni stage config from the default config without introducing another bundled YAML
+
 ## Optimization Features
 
 - **[TeaCache Configuration](../user_guide/diffusion/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss
```
Lines changed: 176 additions & 0 deletions

# Prefill-Decode (PD) Disaggregation

PD disaggregation splits the Qwen3-Omni thinker into separate prefill and decode
stages so prompt processing and token generation can run on different workers.

This is documented as a stage-config recipe instead of a bundled YAML because the
deployment-specific values usually change per environment:

- GPU placement
- `tensor_parallel_size`
- connector backend and connector ports
- connector IPs or bootstrap addresses

Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml)
and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the
changes below.

## Requirements

- 3+ GPUs for a basic layout: prefill, decode, and talker+code2wav
- A KV connector supported by vLLM, such as `MooncakeConnector`
- Matching `tensor_parallel_size` on the prefill and decode thinker stages
## 1. Split the thinker into prefill and decode stages

Replace the original thinker stage with two stages:

```yaml
stage_args:
  - stage_id: 0
    stage_type: llm
    is_prefill_only: true
    runtime:
      devices: "0"
      max_batch_size: 16
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      kv_transfer_config:
        kv_connector: "MooncakeConnector"
        kv_role: "kv_producer"
        kv_rank: 0
        kv_parallel_size: 2
        kv_connector_extra_config:
          mooncake_bootstrap_port: 25201
    final_output: false
    is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 2048
      seed: 42
      detokenize: True
      repetition_penalty: 1.05

  - stage_id: 1
    stage_type: llm
    is_decode_only: true
    runtime:
      devices: "1"
      max_batch_size: 64
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      kv_transfer_config:
        kv_connector: "MooncakeConnector"
        kv_role: "kv_consumer"
        kv_rank: 1
        kv_parallel_size: 2
        kv_connector_extra_config:
          mooncake_bootstrap_port: 25202
    engine_input_source: [0]
    final_output: true
    final_output_type: text
    is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 2048
      seed: 42
      detokenize: True
      repetition_penalty: 1.05
```
Notes:

- `is_prefill_only: true` marks the thinker stage that only saves KV.
- `is_decode_only: true` marks the thinker stage that resumes from remote KV.
- `kv_transfer_config` is required on both stages.
- The orchestrator forces the prefill stage to run with `max_tokens=1`, so the
  prefill side only processes the prompt and exports KV.
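The `max_tokens=1` behaviour can be pictured with a small helper. This is an illustration of the rule only, not vllm-omni's actual orchestrator code; the function name and dict shape are hypothetical.

```python
def effective_sampling_params(stage: dict) -> dict:
    """Illustrative only: a prefill-only stage is clamped to max_tokens=1,
    so it processes the prompt, emits a single token, and exports KV."""
    params = dict(stage.get("default_sampling_params", {}))
    if stage.get("is_prefill_only"):
        params["max_tokens"] = 1  # prompt processing + KV export only
    return params

prefill = {"is_prefill_only": True, "default_sampling_params": {"max_tokens": 2048}}
assert effective_sampling_params(prefill)["max_tokens"] == 1
```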
## 2. Shift the downstream stages by one index

After inserting the extra thinker stage, renumber the remaining stages:

```yaml
  - stage_id: 2
    runtime:
      devices: "2"
    engine_input_source: [1]
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker

  - stage_id: 3
    runtime:
      devices: "2"
      max_batch_size: 1
    engine_input_source: [2]
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
```

Compared with the default Qwen3-Omni config:

- the talker becomes stage `2` instead of stage `1`
- the code2wav stage becomes stage `3` instead of stage `2`
- the talker now reads from decode stage `1`
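The renumbering above is mechanical, so it can be scripted. A minimal sketch, assuming each downstream stage is a dict with `stage_id` and `engine_input_source` keys as in the YAML above; the helper itself is hypothetical, not part of vllm-omni:

```python
def shift_downstream_stages(stages: list[dict]) -> list[dict]:
    """Renumber downstream stages after inserting a decode-only thinker as
    stage 1: every old stage i becomes i + 1, and every engine_input_source
    entry is bumped by one (the consumable thinker output now comes from
    decode stage 1 instead of stage 0)."""
    shifted = []
    for stage in stages:
        stage = dict(stage)
        stage["stage_id"] += 1
        if "engine_input_source" in stage:
            stage["engine_input_source"] = [s + 1 for s in stage["engine_input_source"]]
        shifted.append(stage)
    return shifted

# The default talker/code2wav stages map exactly as described above.
talker = {"stage_id": 1, "engine_input_source": [0]}
code2wav = {"stage_id": 2, "engine_input_source": [1]}
assert shift_downstream_stages([talker, code2wav]) == [
    {"stage_id": 2, "engine_input_source": [1]},
    {"stage_id": 3, "engine_input_source": [2]},
]
```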
## 3. Add runtime edges for the four-stage pipeline

```yaml
runtime:
  enabled: true
  defaults:
    window_size: -1
    max_inflight: 1
  edges:
    - from: 0
      to: 1
      window_size: -1
    - from: 1
      to: 2
      window_size: -1
    - from: 2
      to: 3
      window_size: -1
```
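As a sanity check, the edges should wire the four stages into a single linear chain. A small hypothetical validator (not part of vllm-omni) that captures the invariant:

```python
def edges_form_chain(edges: list[dict], num_stages: int) -> bool:
    """True if the edges form the linear pipeline 0 -> 1 -> ... -> num_stages - 1."""
    got = sorted((e["from"], e["to"]) for e in edges)
    want = [(i, i + 1) for i in range(num_stages - 1)]
    return got == want

pd_edges = [{"from": 0, "to": 1}, {"from": 1, "to": 2}, {"from": 2, "to": 3}]
assert edges_form_chain(pd_edges, 4)
```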
## 4. Launch with your custom config

```bash
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
    --stage-configs-path /path/to/qwen3_omni_pd.yaml
```
## Operational Notes

- `MooncakeConnector` does not support heterogeneous TP sizes across the PD
  pair. Keep prefill and decode at the same `tensor_parallel_size`.
- If the thinker requires TP=2, both thinker stages must use TP=2 and be given
  separate GPU sets, for example `"0,1"` for prefill and `"2,3"` for decode.
- Choose connector ports and addresses that match your deployment. The values
  shown above are examples only.
Lines changed: 13 additions & 0 deletions

```python
"""Patched KV transfer connectors for PD disaggregation.

This package provides monkey-patched versions of vLLM's native KV transfer
connectors (e.g. MooncakeConnector) that fix the request-ID mismatch problem
in prefill-decode disaggregation.

vLLM's ``InputProcessor.assign_request_id()`` appends a random 8-char suffix
to each request ID internally. The prefill engine stores KV under its own
suffix, but the decode engine generates a *different* suffix, so it can never
find the KV data. The patched connector threads the prefill engine's internal
``remote_request_id`` through ``kv_transfer_params`` so the decode side can
reference the correct KV entry.
"""
```
Lines changed: 197 additions & 0 deletions

```python
"""Monkey-patch vLLM's MooncakeConnector to fix request-ID mismatch in PD disaggregation.

vLLM's InputProcessor appends a random suffix to each request ID. The prefill
engine stores KV under its suffix, but the decode engine generates a different
suffix. This patch threads ``remote_request_id`` through ``kv_transfer_params``
so the decode side references the correct KV entry.
"""

from __future__ import annotations

import logging
import sys
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)

_patched: bool = False


@dataclass
class PatchedRecvReqMeta:
    """Receive-request metadata carrying the prefill engine's request ID."""

    request_id: str
    remote_request_id: str
    local_block_ids: list[int]
    kv_transfer_params: dict[str, Any]


def _import_mooncake_module():
    """Import MooncakeConnector module, supporting both vLLM >=0.16 and older."""
    try:
        from vllm.distributed.kv_transfer.kv_connector.v1.mooncake import mooncake_connector

        return mooncake_connector
    except ImportError:
        pass
    try:
        from vllm.distributed.kv_transfer.kv_connector.v1 import mooncake_connector

        return mooncake_connector
    except ImportError:
        return None


def _create_patched_mooncake_connector():
    """Return a subclass of MooncakeConnector with remote_request_id support."""
    try:
        from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import (
            MooncakeConnector as _OriginalMooncakeConnector,
        )
    except (ImportError, AttributeError):
        from vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector import (
            MooncakeConnector as _OriginalMooncakeConnector,
        )

    class PatchedMooncakeConnector(_OriginalMooncakeConnector):
        """Fixes request-ID mismatch in PD disaggregation by injecting
        remote_request_id on the prefill side and using it for KV lookup
        on the decode side.
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.remote_to_local_req: dict[str, str] = {}
            logger.info("[PatchedMooncakeConnector] Initialized")

        def request_finished(
            self,
            request: Any,
            block_ids: list[int],
        ) -> tuple[bool, dict[str, Any] | None]:
            result = super().request_finished(request, block_ids)

            if isinstance(result, tuple) and len(result) == 2:
                delay_free, kv_params = result
            else:
                delay_free, kv_params = False, result

            # Normalise _reqs_need_send values
            req_id = getattr(request, "request_id", None)
            if req_id and hasattr(self, "_reqs_need_send"):
                entry = self._reqs_need_send.get(req_id)
                if isinstance(entry, tuple) and len(entry) == 2:
                    self._reqs_need_send[req_id] = entry[1]

            # Inject remote_request_id into kv_transfer_params
            if kv_params is not None and isinstance(kv_params, dict):
                kv_params["remote_request_id"] = req_id or "NOT_SET"
                if hasattr(self, "side_channel_host"):
                    kv_params.setdefault("remote_host", self.side_channel_host)
                if hasattr(self, "side_channel_port"):
                    kv_params.setdefault("remote_port", self.side_channel_port)

            return delay_free, kv_params

        def add_new_req(
            self,
            request_id: str,
            local_block_ids: list[int],
            kv_transfer_params: dict[str, Any] | None = None,
            **kwargs: Any,
        ) -> None:
            super().add_new_req(request_id, local_block_ids, kv_transfer_params, **kwargs)

            kv_transfer_params = kv_transfer_params or {}
            load_remote_cache = kv_transfer_params.get(
                "do_remote_prefill",
                kv_transfer_params.get("load_remote_cache", False),
            )

            if load_remote_cache:
                remote_request_id = kv_transfer_params.get("remote_request_id", request_id)
                meta = PatchedRecvReqMeta(
                    request_id=request_id,
                    remote_request_id=remote_request_id,
                    local_block_ids=local_block_ids,
                    kv_transfer_params=kv_transfer_params,
                )
                if not hasattr(self, "_reqs_need_recv"):
                    self._reqs_need_recv = {}
                self._reqs_need_recv[request_id] = meta

        def group_kv_pull(self, metadata: Any | None = None) -> None:
            """Use remote_request_id as ZMQ lookup key via save-patch-restore."""
            if not hasattr(self, "_reqs_need_recv") or not self._reqs_need_recv:
                return

            original_recv = self._reqs_need_recv.copy()
            patched_recv: dict[str, Any] = {}

            for local_id, meta in original_recv.items():
                if isinstance(meta, PatchedRecvReqMeta):
                    remote_id = meta.remote_request_id
                    self.remote_to_local_req[remote_id] = local_id
                    patched_meta = type(meta)(
                        request_id=remote_id,
                        remote_request_id=remote_id,
                        local_block_ids=meta.local_block_ids,
                        kv_transfer_params=meta.kv_transfer_params,
                    )
                    patched_recv[remote_id] = patched_meta
                else:
                    patched_recv[local_id] = meta

            self._reqs_need_recv = patched_recv
            super().group_kv_pull(metadata)

            # Restore unconsumed entries to original local keys
            for remote_id, local_id in list(self.remote_to_local_req.items()):
                if remote_id in self._reqs_need_recv:
                    entry = self._reqs_need_recv.pop(remote_id)
                    self._reqs_need_recv[local_id] = original_recv.get(local_id, entry)

        def receive_kv(self, path: Any = None, req_blocks: Any = None) -> Any:
            result = super().receive_kv(path, req_blocks)

            if self.remote_to_local_req:
                completed = [
                    rid
                    for rid, lid in self.remote_to_local_req.items()
                    if not hasattr(self, "_reqs_need_recv") or lid not in self._reqs_need_recv
                ]
                for remote_id in completed:
                    self.remote_to_local_req.pop(remote_id, None)

            return result

    PatchedMooncakeConnector.__qualname__ = _OriginalMooncakeConnector.__qualname__

    return PatchedMooncakeConnector


def apply_mooncake_connector_patch() -> bool:
    """Replace vLLM's MooncakeConnector with the patched version."""
    global _patched
    if _patched:
        return True

    _mc_module = _import_mooncake_module()
    if _mc_module is None:
        logger.warning("[monkey_patch] Cannot import MooncakeConnector — patch NOT applied.")
        return False

    _OriginalClass = _mc_module.MooncakeConnector

    PatchedClass = _create_patched_mooncake_connector()

    _mc_module.MooncakeConnector = PatchedClass
    # Snapshot sys.modules to avoid mutation-during-iteration errors if an
    # import happens concurrently while we rebind the class.
    for module in list(sys.modules.values()):
        if hasattr(module, "MooncakeConnector") and module.MooncakeConnector is _OriginalClass:
            module.MooncakeConnector = PatchedClass

    _patched = True
    logger.info("[monkey_patch] MooncakeConnector patch applied")
    return True
```
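The class-swap loop at the end of `apply_mooncake_connector_patch` follows a standard patch-once idiom: rebind the class in its home module, then in every module that re-imported it by value. A toy, self-contained sketch of the same pattern (all names here are illustrative, not vLLM's):

```python
import sys
import types

# Toy module holding the "original" class, standing in for vLLM's
# mooncake_connector module.
mod = types.ModuleType("toy_connector_mod")

class Original:
    pass

mod.Connector = Original
sys.modules["toy_connector_mod"] = mod

# A second module that re-imported the class by value (a from-import).
alias = types.ModuleType("toy_alias_mod")
alias.Connector = Original
sys.modules["toy_alias_mod"] = alias

class Patched(Original):
    pass

# Rebind the class everywhere it is still the original, mirroring the loop
# in apply_mooncake_connector_patch().
for m in list(sys.modules.values()):
    if getattr(m, "Connector", None) is Original:
        m.Connector = Patched

assert mod.Connector is Patched and alias.Connector is Patched
```

Subclassing (rather than replacing the class wholesale) keeps `isinstance` checks against the original class working after the patch.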
