
Commit 9324e10

Fix KV sharing fast prefill with cudagraph enabled (vllm-project#28537)
Signed-off-by: Yong Hoon Shin <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 4516d44 commit 9324e10

3 files changed: +17 -57 lines changed


tests/v1/e2e/test_kv_sharing_fast_prefill.py

Lines changed: 14 additions & 43 deletions
@@ -4,13 +4,11 @@
 import random
 
 import pytest
-import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
-from vllm.distributed import cleanup_dist_env_and_memory
 
-from ...utils import fork_new_process_for_each_test
+from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
 # global seed
 SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
     return prompts
 
 
-def cleanup(llm: LLM, compilation_config: CompilationConfig):
-    # hacky: below lines are required to free up memory for the next test
-    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
-    # TODO(sarckk): when enforce_eager=False, memory is not freed:
-    # find out why and re-enable test for enforce_eager=False case
-    llm_engine = llm.llm_engine.engine_core.engine_core
-    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
-    del model_runner.model
-    del model_runner.kv_caches
-    del compilation_config.static_forward_context
-    compilation_config.static_forward_context = {}
-
-    del llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
-
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
-@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
+@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
+    kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
     test_prompts: list[str],
 ):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
         if not enforce_eager
         else CompilationMode.NONE,
     )
+    batch_size = 10
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
         m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
-        llm = LLM(
-            model="google/gemma-3n-E2B-it",
-            enforce_eager=enforce_eager,
-            compilation_config=compilation_config,
-            seed=SEED,
-        )
-        ref_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
+        prompts, answer, indices = prep_prompts(batch_size)
 
         llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
            seed=SEED,
-            kv_sharing_fast_prefill=True,
+            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
+        )
+        responses = llm.generate(prompts, sampling_params)
+        check_answers(
+            indices,
+            answer,
+            [response.outputs[0].text for response in responses],
+            accept_rate=1.0,
         )
-        optimized_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
-        misses = 0
-
-        for ref_response, optimized_response in zip(ref_responses, optimized_responses):
-            if ref_response.outputs[0].text != optimized_response.outputs[0].text:
-                misses += 1
-
-        assert misses == 0
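
For readers unfamiliar with the shared test helpers referenced above: prep_prompts and check_answers come from the repo's tests/ utilities, and their names and call signatures are taken from the diff, but their internals are not shown here. The sketch below is a rough, self-contained approximation for illustration only; the prompt wording, key format, and substring matching rule are assumptions, not the real helpers.

import random
import string


def prep_prompts(batch_size: int) -> tuple[list[str], list[str], list[int]]:
    """Build prompts that each hide a random key the model must echo back."""
    prompts, answers, indices = [], [], []
    for i in range(batch_size):
        key = "".join(random.choices(string.ascii_lowercase, k=8))
        prompts.append(f"Remember the key '{key}'. Reply with the key only.")
        answers.append(key)
        indices.append(i)
    return prompts, answers, indices


def check_answers(
    indices: list[int],
    answers: list[str],
    outputs: list[str],
    accept_rate: float = 1.0,
) -> None:
    """Require at least accept_rate of the checked outputs to contain their key."""
    hits = sum(answers[i] in outputs[i] for i in indices)
    assert hits >= accept_rate * len(indices), f"only {hits}/{len(indices)} matched"


if __name__ == "__main__":
    prompts, answers, indices = prep_prompts(3)
    # Pretend the model echoed every key correctly.
    check_answers(indices, answers, [f"The key is {k}" for k in answers])

With accept_rate=1.0, as in the rewritten test, every request must return its key, so any output divergence under kv_sharing_fast_prefill with cudagraphs enabled surfaces as a hard assertion failure.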

vllm/v1/attention/backends/utils.py

Lines changed: 2 additions & 13 deletions
@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
     return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
 
 
-KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
-    ("logits_indices_padded", torch.Tensor | None, None),
-    ("num_logits_indices", int, 0),
-]
-
-
 def subclass_attention_metadata(
     name_prefix: str,
     metadata_cls: Any,
@@ -986,8 +980,8 @@ def subclass_attention_metadata(
 
 @runtime_checkable
 class KVSharingFastPrefillMetadata(Protocol):
-    logits_indices_padded: torch.Tensor
-    num_logits_indices: int
+    logits_indices_padded: torch.Tensor | None = None
+    num_logits_indices: int | None = None
 
 
 def create_fast_prefill_custom_backend(
@@ -1019,11 +1013,6 @@ def __init__(self, metadata, common_attn_metadata):
                 for _field in fields(metadata.__class__):
                     setattr(self, _field.name, getattr(metadata, _field.name))
 
-                # Set additional fields that will be used in model code
-                assert (
-                    common_attn_metadata.logits_indices_padded is not None
-                    and common_attn_metadata.num_logits_indices is not None
-                )
                 self.logits_indices_padded = (
                     common_attn_metadata.logits_indices_padded
                 )
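
A side note on the Protocol tweak above: isinstance() against a @runtime_checkable Protocol with data members only verifies that attributes with those names exist on the object, so relaxing them to Optional with None defaults keeps the structural check valid even when the fields have not been populated yet (presumably the point of dropping the assert). The snippet below is a standalone sketch of that Python behavior with made-up class names; it is not vLLM code.

from typing import Protocol, runtime_checkable


@runtime_checkable
class FastPrefillLike(Protocol):
    # Data members of a runtime_checkable Protocol: isinstance() only tests
    # that attributes with these names exist on the instance.
    logits_indices_padded: list[int] | None
    num_logits_indices: int | None


class WithFields:
    def __init__(self) -> None:
        self.logits_indices_padded = None  # present but not yet populated
        self.num_logits_indices = None


class WithoutFields:
    pass


print(isinstance(WithFields(), FastPrefillLike))     # True: attributes exist
print(isinstance(WithoutFields(), FastPrefillLike))  # False: attributes missing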

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -1314,7 +1314,7 @@ def _build_attention_metadata(
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
         logits_indices_padded = None
-        num_logits_indices = 0
+        num_logits_indices = None
         if logits_indices is not None:
             num_logits_indices = logits_indices.size(0)
             if self.cache_config.kv_sharing_fast_prefill:
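
This one-line change aligns the runner's default with the int | None = None declaration on the Protocol in the previous file. A minimal sketch of that sentinel convention follows; the variable name mirrors the diff, but the rationale wording is an assumption, not taken from the commit message.

def describe(num_logits_indices: int | None) -> str:
    # None marks "fast-prefill metadata not populated", so 0 remains a
    # legitimate count and the two states stay distinguishable.
    if num_logits_indices is None:
        return "fast prefill metadata not populated"
    return f"{num_logits_indices} logits indices"


print(describe(None))  # fast prefill metadata not populated
print(describe(0))     # 0 logits indices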
