Commit e37a89f
[Feature]: Support cfg kv-cache transfer in multi-stage (vllm-project#1422)
Signed-off-by: princepride <wangzhipeng628@gmail.com> Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 68764cc commit e37a89f

File tree

21 files changed (+805 −30 lines)

docs/configuration/stage_configs.md

Lines changed: 8 additions & 0 deletions
@@ -135,6 +135,14 @@ Each stage in the `stage_args` list contains the following configuration options
 
 A unique identifier for each stage in the multi-stage pipeline. Stages are numbered sequentially starting from 0, and this ID is used to reference stages in inter-stage dependencies (e.g., `engine_input_source`).
 
+### `prompt_expand_func` (Optional)
+
+A custom Python function hook for the LLM stage (Stage 0) that expands a single incoming prompt into multiple prompts. This is primarily used for multi-modal Classifier-Free Guidance (CFG): the hook generates the necessary companion requests (such as a negative text prompt) and tags them with internal roles (e.g., `cfg_text`), so the upstream LLM produces hidden states for both the conditional and unconditional branches in a single pass.
+
+### `cfg_kv_collect_func` (Optional)
+
+A custom Python function hook for downstream diffusion stages (Stage 1+) that collects and maps the KV caches transferred from the companion requests created by `prompt_expand_func`. It binds them to role-specific fields (e.g., `cfg_text_past_key_values` and `cfg_text_kv_metadata`), so the diffusion runtime can apply CFG without re-evaluating the text paths on the DiT workers.
+
 ### `runtime`
 
 Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed.
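To illustrate the two hooks, here is a minimal sketch. The signatures and return shapes are assumptions inferred from the `dummy_expand_func` used in the tracker tests added by this commit and from the field names above, not the authoritative API:

```python
from types import SimpleNamespace

# Hypothetical expand hook (Stage 0). Mirrors the shape exercised by the
# CfgCompanionTracker tests in this commit: for each positive prompt it emits
# one companion request tagged with the internal role "cfg_text".
def expand_cfg_prompts(prompt, stage0_sampling_params):
    negative_text = ""  # default negative prompt for CFG
    return [
        SimpleNamespace(
            prompt={"prompt": negative_text},
            role="cfg_text",
            request_id_suffix="__cfg_text",
        )
    ]

# Hypothetical collect hook (Stage 1+). Maps companion KV caches received on
# the diffusion stage into role-keyed fields the pipeline can bind later
# (e.g. cfg_text_past_key_values / cfg_text_kv_metadata).
def collect_cfg_kv_caches(role_to_kv):
    collected = {}
    for role, (past_key_values, kv_metadata) in role_to_kv.items():
        collected[f"{role}_past_key_values"] = past_key_values
        collected[f"{role}_kv_metadata"] = kv_metadata
    return collected
```

The real processors referenced by the CI configs in this commit live under `vllm_omni.model_executor.stage_input_processors.bagel`; the above is only a shape sketch.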

docs/design/architecture_overview.md

Lines changed: 7 additions & 0 deletions
@@ -92,6 +92,13 @@ The framework achieves high performance through several optimization techniques:
 * **Quantization:** Supports various quantization implementations including FP8 and AWQ.
 * **FusedOps:** Allows for custom and third-party integration.
 
+### Classifier-Free Guidance (CFG) Companion Flow
+
+vLLM-Omni models Classifier-Free Guidance (CFG) across disaggregated multi-stage setups via a "companion request" paradigm, which avoids recomputing textual/multimodal context on both sides of a stage boundary:
+1. **Prompt Expansion:** In the initial autoregressive (AR) stage, a `prompt_expand_func` hook intercepts each incoming generation prompt and pairs it with a companion prompt (e.g., a default negative prompt), tagging the companion with an internal role (`cfg_text`).
+2. **Synchronized KV Cache Transfer:** The AR stage evaluates the primary and companion sequences concurrently. The `OmniConnector` tracks the dependency between them and transfers both KV caches across the stage boundary via shared memory or network protocols.
+3. **KV Cache Collection & Injection:** In the downstream Diffusion (DiT) engine, the configured `cfg_kv_collect_func` collects the companion caches (e.g., `cfg_text_past_key_values`) and binds them to the primary generation request, so the DiT engine can run the conditional and unconditional CFG branches in parallel without re-running the text path.
+
 ### Flexibility and Usability
 
 vLLM-Omni is designed to be flexible and straightforward for users:
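The CI configs added in this commit wire this flow up by pointing the two stage hooks at model-specific processors; abridged from `bagel_mooncake_ci.yaml` (runtime and connector sections omitted):

```yaml
stage_args:
- stage_id: 0
  stage_type: llm
  prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
- stage_id: 1
  stage_type: diffusion
  cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
```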

examples/offline_inference/bagel/end2end.py

Lines changed: 7 additions & 2 deletions
@@ -49,7 +49,7 @@ def parse_args():
     parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
     parser.add_argument("--cfg-img-scale", type=float, default=1.5, help="Image CFG scale (default: 1.5)")
     parser.add_argument(
-        "--negative-prompt", type=str, default=None, help="Negative prompt (not yet supported, reserved for future)"
+        "--negative-prompt", type=str, default=None, help="Negative prompt for CFG (default: empty prompt)"
     )
 
     args = parser.parse_args()
@@ -162,6 +162,8 @@ def main():
             # text2img
             final_prompt_text = f"<|im_start|>{p}<|im_end|>"
             prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
+            if args.negative_prompt is not None:
+                prompt_dict["negative_prompt"] = args.negative_prompt
             formatted_prompts.append(prompt_dict)
 
     params_list = omni.default_sampling_params_list
@@ -170,10 +172,13 @@ def main():
     if len(params_list) > 1:
         diffusion_params = params_list[1]
         diffusion_params.num_inference_steps = args.steps  # type: ignore
-        diffusion_params.extra_args = {  # type: ignore
+        extra = {
             "cfg_text_scale": args.cfg_text_scale,
             "cfg_img_scale": args.cfg_img_scale,
         }
+        if args.negative_prompt is not None:
+            extra["negative_prompt"] = args.negative_prompt
+        diffusion_params.extra_args = extra  # type: ignore
 
     omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

tests/diffusion/test_diffusion_model_runner.py

Lines changed: 4 additions & 1 deletion
@@ -56,7 +56,10 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar
         enable_cache_dit_summary=enable_cache_dit_summary,
         parallel_config=SimpleNamespace(use_hsdp=False),
     )
-    runner.kv_transfer_manager = SimpleNamespace(receive_kv_cache=lambda req, target_device: None)
+    runner.kv_transfer_manager = SimpleNamespace(
+        receive_kv_cache=lambda req, target_device=None: None,
+        receive_multi_kv_cache=lambda req, cfg_kv_collect_func=None, target_device=None: None,
+    )
     return runner

tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 stage_args:
 - stage_id: 0
   stage_type: llm
+  prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
   runtime:
     devices: "0"
     max_batch_size: 1
@@ -39,6 +40,7 @@ stage_args:
       to_stage_1: mooncake_connector
 - stage_id: 1
   stage_type: diffusion
+  cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
   runtime:
     devices: "0"
     max_batch_size: 1

tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 stage_args:
 - stage_id: 0
   stage_type: llm
+  prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
   runtime:
     devices: "0"
     max_batch_size: 1
@@ -38,6 +39,7 @@ stage_args:
 
 - stage_id: 1
   stage_type: diffusion
+  cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
   runtime:
     devices: "0"
     max_batch_size: 1

tests/e2e/offline_inference/test_bagel_text2img.py

Lines changed: 14 additions & 10 deletions
@@ -37,16 +37,16 @@
 # "Generated with seed=52, num_inference_steps=15,
 # prompt='A futuristic city skyline at twilight, cyberpunk style'"
 REFERENCE_PIXELS = [
-    {"position": (100, 100), "rgb": (68, 107, 134)},
-    {"position": (400, 50), "rgb": (95, 139, 166)},
-    {"position": (700, 100), "rgb": (99, 122, 151)},
-    {"position": (150, 400), "rgb": (111, 125, 153)},
-    {"position": (512, 512), "rgb": (97, 107, 131)},
-    {"position": (700, 400), "rgb": (48, 64, 98)},
-    {"position": (100, 700), "rgb": (79, 63, 84)},
-    {"position": (400, 700), "rgb": (40, 58, 79)},
-    {"position": (700, 700), "rgb": (60, 75, 103)},
-    {"position": (256, 256), "rgb": (97, 128, 156)},
+    {"position": (100, 100), "rgb": (49, 96, 134)},
+    {"position": (400, 50), "rgb": (63, 127, 167)},
+    {"position": (700, 100), "rgb": (70, 101, 141)},
+    {"position": (150, 400), "rgb": (115, 90, 150)},
+    {"position": (512, 512), "rgb": (98, 86, 119)},
+    {"position": (700, 400), "rgb": (29, 42, 91)},
+    {"position": (100, 700), "rgb": (47, 50, 88)},
+    {"position": (400, 700), "rgb": (36, 52, 91)},
+    {"position": (700, 700), "rgb": (45, 58, 99)},
+    {"position": (256, 256), "rgb": (62, 94, 135)},
 ]
 
 # Maximum allowed difference per color channel
@@ -80,6 +80,10 @@ def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_st
     params_list[0].max_tokens = max_tokens  # type: ignore
     if len(params_list) > 1:
         params_list[1].num_inference_steps = num_inference_steps  # type: ignore
+        params_list[1].extra_args = {  # type: ignore
+            "cfg_text_scale": 4.0,
+            "cfg_img_scale": 1.5,
+        }
     return params_list

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+import time
+from types import SimpleNamespace
+
+import pytest
+
+from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def dummy_expand_func(prompt, sp0):
+    if prompt == "expand_me":
+        return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")]
+    return []
+
+
+@pytest.fixture
+def tracker():
+    sp0 = SimpleNamespace()
+    return CfgCompanionTracker(prompt_expand_func=dummy_expand_func, stage0_sampling_params=sp0, timeout_s=0.1)
+
+
+def test_companion_tracker_initialization(tracker):
+    assert not tracker.is_active
+    assert tracker.num_companions == 0
+
+
+def test_expand_prompts_registers_companions(tracker):
+    request_id_to_prompt = {"req1": "expand_me", "req2": "do_not_expand"}
+
+    pairs = tracker.expand_prompts(request_id_to_prompt)
+
+    assert len(pairs) == 1
+    companion_id, prompt = pairs[0]
+    assert companion_id == "req1__cfg_text"
+    assert prompt == {"prompt": "neg"}
+
+    assert tracker.is_active
+    assert tracker.num_companions == 1
+    assert tracker.is_companion("req1__cfg_text")
+    assert not tracker.is_companion("req2__cfg_text")
+    assert tracker.has_companions("req1")
+    assert not tracker.has_companions("req2")
+
+    comp_map = tracker.get_companion_request_ids("req1")
+    assert comp_map == {"cfg_text": "req1__cfg_text"}
+
+
+def test_companion_lifecycle_success(tracker):
+    request_id_to_prompt = {"req1": "expand_me"}
+    tracker.expand_prompts(request_id_to_prompt)
+
+    # Defer parent
+    engine_outputs = {"out": 123}
+    tracker.defer_parent("req1", engine_outputs, stage_id=0)
+
+    # Initially not done
+    assert not tracker.all_companions_done("req1")
+
+    # Companion completes
+    parent_id = tracker.on_companion_completed("req1__cfg_text")
+
+    # Parent should be returned since all companions are done and it is pending
+    assert parent_id == "req1"
+    assert tracker.all_companions_done("req1")
+
+    # Pop pending parent
+    popped = tracker.pop_pending_parent("req1")
+    assert popped is not None
+    assert popped["engine_outputs"] == engine_outputs
+    assert popped["stage_id"] == 0
+
+
+def test_companion_lifecycle_failure(tracker):
+    request_id_to_prompt = {"req1": "expand_me"}
+    tracker.expand_prompts(request_id_to_prompt)
+
+    tracker.defer_parent("req1", {"out": 123}, stage_id=0)
+
+    # Companion fails
+    parent_id, aborted = tracker.on_companion_error("req1__cfg_text")
+
+    assert parent_id == "req1"
+    assert aborted is True
+    assert tracker.is_parent_failed("req1")
+
+    # Parent should be removed from pending list
+    assert tracker.pop_pending_parent("req1") is None
+
+    # Consume failure
+    tracker.consume_parent_failure("req1")
+    assert not tracker.is_parent_failed("req1")
+
+
+def test_companion_lifecycle_timeout(tracker):
+    request_id_to_prompt = {"req1": "expand_me"}
+    tracker.expand_prompts(request_id_to_prompt)
+
+    tracker.defer_parent("req1", {"out": 123}, stage_id=0)
+
+    # Initially no timeouts
+    timeouts = tracker.check_timeouts()
+    assert len(timeouts) == 0
+
+    # Wait for timeout
+    time.sleep(0.15)
+
+    # Check timeouts again
+    timeouts = tracker.check_timeouts()
+    assert len(timeouts) == 1
+    assert timeouts[0] == "req1"
+
+    # Should be removed from pending
+    assert tracker.pop_pending_parent("req1") is None
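As a toy model of the lifecycle these tests pin down — expand, defer the parent, resume once every companion finishes, drop parents that time out — consider the following stand-in (an illustrative sketch only, not the real `CfgCompanionTracker` from `vllm_omni.entrypoints.cfg_companion_tracker`, which carries more bookkeeping):

```python
import time

class ToyCompanionTracker:
    """Toy stand-in for the companion lifecycle contract tested above."""

    def __init__(self, timeout_s):
        self.timeout_s = timeout_s
        self.parent_to_companions = {}  # parent_id -> {role: companion_id}
        self.done_companions = set()
        self.pending_parents = {}       # parent_id -> (payload, deadline)

    def register(self, parent_id, role, companion_id):
        self.parent_to_companions.setdefault(parent_id, {})[role] = companion_id

    def defer_parent(self, parent_id, payload):
        # Park the parent's outputs until its companions finish or time out.
        self.pending_parents[parent_id] = (payload, time.monotonic() + self.timeout_s)

    def on_companion_completed(self, companion_id):
        self.done_companions.add(companion_id)
        for parent_id, comps in self.parent_to_companions.items():
            if companion_id in comps.values() and all(
                c in self.done_companions for c in comps.values()
            ):
                return parent_id  # parent is ready to resume
        return None

    def pop_pending_parent(self, parent_id):
        entry = self.pending_parents.pop(parent_id, None)
        return entry[0] if entry else None

    def check_timeouts(self):
        now = time.monotonic()
        expired = [p for p, (_, dl) in self.pending_parents.items() if dl < now]
        for p in expired:
            del self.pending_parents[p]
        return expired
```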

vllm_omni/diffusion/data.py

Lines changed: 3 additions & 0 deletions
@@ -442,6 +442,9 @@ class OmniDiffusionConfig:
     # Omni configuration (injected from stage config)
     omni_kv_config: dict[str, Any] = field(default_factory=dict)
 
+    # Model-specific function for collecting CFG KV caches (set at runtime)
+    cfg_kv_collect_func: Any | None = None
+
     # Quantization settings
     # Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs)
     quantization: str | None = None

vllm_omni/diffusion/models/bagel/pipeline_bagel.py

Lines changed: 34 additions & 10 deletions
@@ -327,16 +327,40 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
            gen_context["past_key_values"] = injected_kv
            seq_len = injected_kv.key_cache[0].shape[0]
            gen_context["kv_lens"] = [seq_len]
-            gen_context["ropes"] = [seq_len]
-
-            # Disable CFG: single KV cache cannot support 3-branch CFG
-            logger.warning("CFG is disabled when using injected KV Cache")
-            gen_params = BagelGenParams(
-                num_timesteps=gen_params.num_timesteps,
-                timestep_shift=gen_params.timestep_shift,
-                cfg_text_scale=1.0,
-                cfg_img_scale=1.0,
-            )
+            if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
+                gen_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
+            else:
+                gen_context["ropes"] = [seq_len]
+
+            cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
+            if cfg_text_kv is not None:
+                logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
+                cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0]
+                cfg_text_context["past_key_values"] = cfg_text_kv
+                cfg_text_context["kv_lens"] = [cfg_text_seq_len]
+                cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
+                if cfg_text_metadata and "ropes" in cfg_text_metadata:
+                    cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
+                else:
+                    cfg_text_context["ropes"] = [cfg_text_seq_len]
+
+                cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv
+                cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
+                cfg_img_context["past_key_values"] = cfg_img_kv
+                cfg_img_context["kv_lens"] = [cfg_img_seq_len]
+                cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None)
+                if cfg_img_metadata and "ropes" in cfg_img_metadata:
+                    cfg_img_context["ropes"] = cfg_img_metadata["ropes"]
+                else:
+                    cfg_img_context["ropes"] = [cfg_img_seq_len]
+            else:
+                logger.warning("CFG is disabled: only single KV cache available")
+                gen_params = BagelGenParams(
+                    num_timesteps=gen_params.num_timesteps,
+                    timestep_shift=gen_params.timestep_shift,
+                    cfg_text_scale=1.0,
+                    cfg_img_scale=1.0,
+                )
 
        else:
            image_input = (
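The hunk above repeats the same bind-and-fallback pattern for the primary, `cfg_text`, and `cfg_img` contexts; a small helper could factor it out. A sketch (the helper name, the dict-shaped context, and the stand-in KV object are assumptions for illustration, not the pipeline's actual API):

```python
from types import SimpleNamespace

def bind_kv_cache(context, kv, metadata=None):
    """Attach an injected KV cache to a generation context; fall back to the
    cache's sequence length for RoPE positions when no metadata was shipped."""
    seq_len = kv.key_cache[0].shape[0]  # rows of the first layer's key cache
    context["past_key_values"] = kv
    context["kv_lens"] = [seq_len]
    context["ropes"] = metadata["ropes"] if metadata and "ropes" in metadata else [seq_len]
    return context

# Stand-in KV cache object exposing the one attribute the helper touches.
fake_kv = SimpleNamespace(key_cache=[SimpleNamespace(shape=(7, 64))])
gen_context = bind_kv_cache({}, fake_kv)                  # no metadata: ropes falls back
cfg_context = bind_kv_cache({}, fake_kv, {"ropes": [3]})  # shipped metadata wins
```

Each of the three branches in the diff would then reduce to one `bind_kv_cache(...)` call with the matching KV cache and metadata.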
