Commit cf0a99f

[ModelRunner V2] Support spec decode with structured outputs (vllm-project#33374)
Signed-off-by: Nick Hill <[email protected]>
1 parent e535d90 commit cf0a99f

3 files changed: +60 -1 lines changed

vllm/v1/worker/gpu/input_batch.py

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,9 @@ class InputBatch:
     cu_num_logits: torch.Tensor
     cu_num_logits_np: np.ndarray
 
+    # Whether any requests in batch use structured output.
+    has_structured_output_reqs: bool
+
     @classmethod
     def make_dummy(
         cls,
@@ -139,6 +142,7 @@ def make_dummy(
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
+            has_structured_output_reqs=False,
         )
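Note on the pattern: InputBatch precomputes a single boolean at construction time so per-step code can branch without rescanning requests. A minimal, self-contained sketch of that idea (not vLLM's actual class; all names except has_structured_output_reqs are illustrative assumptions):

from dataclasses import dataclass


@dataclass
class BatchSketch:
    req_ids: list[str]
    # Precomputed once per batch: whether any request uses structured
    # output (mirrors InputBatch.has_structured_output_reqs above).
    has_structured_output_reqs: bool

    @classmethod
    def make_dummy(cls) -> "BatchSketch":
        # Dummy batches (e.g. for warmup/profiling) carry no
        # structured-output requests, matching the diff's default.
        return cls(req_ids=[], has_structured_output_reqs=False)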
vllm/v1/worker/gpu/model_runner.py

Lines changed: 11 additions & 1 deletion
@@ -20,7 +20,7 @@
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -59,6 +59,7 @@
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -167,6 +168,10 @@ def __init__(
         # LoRA-related workers.
         self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
 
+        # Draft tokens propagation - for spec-dec + struct outputs.
+        self.draft_tokens_handler = DraftTokensHandler(self.device)
+
+        # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
     def update_max_model_len(self, max_model_len: int) -> None:
@@ -638,6 +643,7 @@ def prepare_inputs(
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
+            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
         )
 
     @torch.inference_mode()
@@ -938,7 +944,11 @@ def sample_tokens(
                 num_rejected,
             )
             self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
+            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
 
         if self.use_async_scheduling:
            return async_output
         return async_output.get_output()
+
+    def take_draft_token_ids(self) -> DraftTokenIds | None:
+        return self.draft_tokens_handler.get_draft_tokens()
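Taken together these hunks wire a round trip: sample_tokens hands each step's accepted draft tokens to the handler, and the scheduler side can later drain them via take_draft_token_ids for grammar validation. A hedged sketch of that consumer loop (only take_draft_token_ids and DraftTokenIds appear in this commit; the loop body and field names are assumptions):

def drain_draft_tokens(model_runner) -> None:
    draft = model_runner.take_draft_token_ids()
    if draft is None:
        # Either no draft tokens were produced this step, or no request
        # in the batch uses structured outputs, so nothing to validate.
        return
    # Assumed DraftTokenIds fields: req_ids and draft_token_ids.
    for req_id, token_ids in zip(draft.req_ids, draft.draft_token_ids):
        # A scheduler would advance the request's grammar matcher here
        # and truncate the draft at the first token it rejects.
        ...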
vllm/v1/worker/gpu/spec_decode/utils.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.v1.outputs import DraftTokenIds
+from vllm.v1.worker.gpu.async_utils import async_copy_to_np
+from vllm.v1.worker.gpu.input_batch import InputBatch
+
+
+class DraftTokensHandler:
+    def __init__(self, device: torch.device | None = None):
+        self.device = device
+        self.copy_stream = torch.cuda.Stream(device)
+        self.copy_event = torch.cuda.Event()
+
+        self.req_ids: list[str] = []
+        self.draft_tokens_np: np.ndarray | None = None
+
+    def set_draft_tokens(
+        self, input_batch: InputBatch, draft_tokens: torch.Tensor
+    ) -> None:
+        if not input_batch.has_structured_output_reqs:
+            # No draft token validation needs to be performed by
+            # the scheduler for this batch.
+            if self.req_ids:
+                self.req_ids = []
+                self.draft_tokens_np = None
+            return
+
+        # For spec decoding + structured outputs, we must transfer the
+        # draft tokens back to the scheduler for grammar validation.
+        self.req_ids = input_batch.req_ids
+        current_stream = torch.cuda.current_stream(self.device)
+        self.copy_stream.wait_stream(current_stream)
+        with torch.cuda.stream(self.copy_stream):
+            self.draft_tokens_np = async_copy_to_np(draft_tokens)
+            self.copy_event.record()
+
+    def get_draft_tokens(self) -> DraftTokenIds | None:
+        if self.draft_tokens_np is None:
+            return None
+
+        self.copy_event.synchronize()
+        return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())
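DraftTokensHandler overlaps the device-to-host transfer with other work: it forks the copy onto a dedicated stream, records an event, and synchronizes only when the scheduler actually asks for the tokens. The async_copy_to_np helper is not shown in this diff; below is a plausible pinned-memory sketch of the same pattern, under the assumption that it returns an ndarray backed by page-locked host memory:

import numpy as np
import torch


def async_copy_to_np_sketch(t: torch.Tensor) -> np.ndarray:
    # Pinned (page-locked) host memory lets the D2H copy run
    # asynchronously on the current CUDA stream.
    host = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    host.copy_(t, non_blocking=True)
    # The ndarray aliases the pinned buffer; it is only safe to read
    # after the copy's stream or event has been synchronized.
    return host.numpy()


# Usage mirroring the handler above:
draft_tokens = torch.randint(0, 32000, (4, 8), device="cuda")
copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()
copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(copy_stream):
    arr = async_copy_to_np_sketch(draft_tokens)
    copy_event.record()
# Later, on the consumer side:
copy_event.synchronize()
print(arr.tolist())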
