Commit bb1848c

[Model Runner V2] Support VLM (vllm-project#32546)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent: 6101a26 · commit: bb1848c

6 files changed: +263 additions, −15 deletions

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 7 additions & 0 deletions
@@ -76,6 +76,7 @@ def capture_graph(
         model: nn.Module,
         input_buffers: InputBuffers,
         mrope_positions: torch.Tensor | None,
+        inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
@@ -86,6 +87,8 @@ def capture_graph(
         if self.uses_mrope:
             assert mrope_positions is not None
             positions = mrope_positions[:, :num_tokens]
+        if inputs_embeds is not None:
+            inputs_embeds = inputs_embeds[:num_tokens]
         attn_metadata = prepare_inputs_to_capture(
             num_reqs,
             num_tokens,
@@ -108,6 +111,7 @@ def capture_graph(
         hidden_states = model(
             input_ids=input_ids,
             positions=positions,
+            inputs_embeds=inputs_embeds,
         )
         if self.hidden_states is None:
             self.hidden_states = torch.empty_like(hidden_states)
@@ -128,6 +132,7 @@ def capture_graph(
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
+                inputs_embeds=inputs_embeds,
             )
         self.hidden_states[:num_tokens] = hidden_states
         self.graphs[num_tokens] = graph
@@ -138,6 +143,7 @@ def capture(
         model: nn.Module,
         input_buffers: InputBuffers,
         mrope_positions: torch.Tensor | None,
+        inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
@@ -149,6 +155,7 @@ def capture(
             model=model,
             input_buffers=input_buffers,
             mrope_positions=mrope_positions,
+            inputs_embeds=inputs_embeds,
             block_tables=block_tables,
             attn_metadata_builders=attn_metadata_builders,
             kv_cache_config=kv_cache_config,
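Note on the change above: the new inputs_embeds argument follows the same pattern as the other persistent input buffers, i.e. a fixed, pre-allocated tensor is sliced to num_tokens before capture so that the replayed graph keeps reading from the same memory. The snippet below is a minimal, self-contained sketch of that generic pattern, not vLLM's capture code; toy_model and the sizes are made up for illustration.

import torch

hidden_size, max_num_tokens = 1024, 8192
# Persistent buffer standing in for the pre-allocated inputs_embeds tensor.
inputs_embeds = torch.zeros(
    max_num_tokens, hidden_size, dtype=torch.float16, device="cuda"
)

def toy_model(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the decoder forward that consumes inputs_embeds.
    return x * 2.0 + 1.0

num_tokens = 256
static_input = inputs_embeds[:num_tokens]  # view into the persistent buffer

# Warm up on a side stream, then capture (the usual CUDA graph recipe).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    toy_model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = toy_model(static_input)

# At run time, overwrite the captured slice in place and replay the graph.
inputs_embeds[:num_tokens].normal_()
graph.replay()
print(static_output.shape)  # torch.Size([256, 1024])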

vllm/v1/worker/gpu/input_batch.py

Lines changed: 3 additions & 3 deletions
@@ -15,9 +15,6 @@ def __init__(
         self,
         max_num_reqs: int,
         max_num_tokens: int,
-        inputs_embeds_size: int,
-        vocab_size: int,
-        dtype: torch.dtype,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
@@ -64,6 +61,8 @@ class InputBatch:
     positions: torch.Tensor
     # [3, num_tokens_after_padding]
     mrope_positions: torch.Tensor | None
+    # [num_tokens_after_padding, hidden_size]
+    inputs_embeds: torch.Tensor | None

     # layer_name -> Metadata
     attn_metadata: dict[str, Any]
@@ -132,6 +131,7 @@ def make_dummy(
             input_ids=input_ids,
             positions=positions,
             mrope_positions=None,
+            inputs_embeds=None,
             attn_metadata=None,  # type: ignore
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
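Note on the change above: InputBatch now carries an optional inputs_embeds tensor alongside input_ids. A minimal sketch of the design this enables (hypothetical ToyBatch/run_forward names, not the real InputBatch API):

from dataclasses import dataclass

import torch


@dataclass
class ToyBatch:
    # [num_tokens]
    input_ids: torch.Tensor
    # [num_tokens_after_padding, hidden_size]; None for text-only batches
    inputs_embeds: torch.Tensor | None = None


def run_forward(model, batch: ToyBatch) -> torch.Tensor:
    if batch.inputs_embeds is not None:
        # VLM path: token embeddings were already merged with image embeddings.
        return model(input_ids=batch.input_ids, inputs_embeds=batch.inputs_embeds)
    # Text-only path: the model embeds input_ids itself.
    return model(input_ids=batch.input_ids, inputs_embeds=None)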
Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
+from vllm.multimodal.utils import group_mm_kwargs_by_modality
+from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
+from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
+
+
+class EncoderRunner:
+    def __init__(
+        self,
+        max_num_tokens: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.max_num_tokens = max_num_tokens
+        self.hidden_size = hidden_size
+        self.dtype = dtype
+        self.device = device
+
+        self.inputs_embeds = torch.zeros(
+            max_num_tokens,
+            hidden_size,
+            dtype=dtype,
+            device=device,
+        )
+        self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
+        self.encoder_cache: dict[str, torch.Tensor] = {}
+
+        self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)
+
+    def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
+        self.req_id_to_mm_features[req_id] = mm_features
+
+    def free_encoder_cache(self, mm_hash: str) -> None:
+        self.encoder_cache.pop(mm_hash, None)
+
+    def remove_request(self, req_id: str) -> None:
+        self.req_id_to_mm_features.pop(req_id, None)
+
+    def prepare_mm_inputs(
+        self,
+        scheduled_encoder_inputs: dict[str, list[int]],
+    ) -> tuple[list[str], list[MultiModalKwargsItem]]:
+        mm_hashes: list[str] = []
+        mm_kwargs: list[MultiModalKwargsItem] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            mm_features = self.req_id_to_mm_features[req_id]
+            for mm_input_id in encoder_input_ids:
+                mm_feature = mm_features[mm_input_id]
+                if mm_feature.data is None:
+                    continue
+                mm_hashes.append(mm_feature.identifier)
+                mm_kwargs.append(mm_feature.data)
+        return mm_hashes, mm_kwargs
+
+    @torch.inference_mode()
+    def execute_mm_encoder(
+        self,
+        model: SupportsMultiModal,
+        mm_hashes: list[str],
+        mm_kwargs: list[MultiModalKwargsItem],
+    ) -> list[torch.Tensor]:
+        if not mm_hashes:
+            return []
+
+        encoder_outputs: list[torch.Tensor] = []
+        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+            mm_kwargs,
+            device=self.device,
+            pin_memory=False,
+        ):
+            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=num_items,
+            )
+            encoder_outputs.extend(curr_group_outputs)
+
+        # Cache the encoder outputs by mm_hash
+        for mm_hash, output in zip(mm_hashes, encoder_outputs):
+            self.encoder_cache[mm_hash] = output
+        return encoder_outputs
+
+    def gather_mm_embeddings(
+        self,
+        req_ids: list[str],
+        total_num_scheduled_tokens: int,
+        num_scheduled_tokens: np.ndarray,
+        query_start_loc: np.ndarray,
+        prefill_lens: np.ndarray,
+        computed_prefill_lens: np.ndarray,
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
+        is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
+        all_decode = not any(is_prefilling)
+        if all_decode:
+            # All decode requests, so no need to gather any embeddings.
+            return [], torch.zeros(
+                total_num_scheduled_tokens,
+                dtype=torch.bool,
+                device=self.device,
+            )
+
+        query_start = computed_prefill_lens.tolist()
+        query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
+
+        mm_embeds: list[torch.Tensor] = []
+        is_mm_embed = torch.zeros(
+            total_num_scheduled_tokens,
+            dtype=torch.bool,
+            device="cpu",
+            pin_memory=False,
+        )
+        for i, req_id in enumerate(req_ids):
+            if not is_prefilling[i]:
+                # OPTIMIZATION: Skip decode requests.
+                continue
+
+            mm_features = self.req_id_to_mm_features[req_id]
+            for mm_feature in mm_features:
+                pos_info = mm_feature.mm_position
+                start_pos = pos_info.offset
+                num_encoder_tokens = pos_info.length
+
+                if start_pos >= query_end[i]:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= query_start[i]:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(query_start[i] - start_pos, 0)
+                end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
+                assert start_idx < end_idx
+                curr_embeds_start, curr_embeds_end = (
+                    pos_info.get_embeds_indices_in_range(start_idx, end_idx)
+                )
+                # If there are no embeddings in the current range, we skip
+                # gathering the embeddings.
+                if curr_embeds_start == curr_embeds_end:
+                    continue
+
+                mm_hash = mm_feature.identifier
+                encoder_output = self.encoder_cache.get(mm_hash, None)
+                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
+
+                if (is_embed := pos_info.is_embed) is not None:
+                    is_embed = is_embed[start_idx:end_idx]
+                    mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
+                else:
+                    mm_embeds_item = encoder_output[start_idx:end_idx]
+
+                req_start_pos = query_start_loc[i] + start_pos - query_start[i]
+                is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
+                    True if is_embed is None else is_embed
+                )
+                mm_embeds.append(mm_embeds_item)
+
+        # Copy the is_mm_embed tensor to the GPU.
+        is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed)
+        return mm_embeds, is_mm_embed
+
+    @torch.inference_mode()
+    def get_inputs_embeds(
+        self,
+        model: SupportsMultiModal,
+        input_ids: torch.Tensor,
+        mm_embeds: list[torch.Tensor],
+        is_mm_embed: torch.Tensor,
+    ) -> torch.Tensor:
+        x = model.embed_input_ids(
+            input_ids,
+            multimodal_embeddings=mm_embeds,
+            is_multimodal=is_mm_embed,
+        )
+        # Copy to the pre-allocated buffer for CUDA graphs.
+        self.inputs_embeds[: x.shape[0]] = x
+        return self.inputs_embeds
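Note on the new EncoderRunner above: the subtle part is the window clipping in gather_mm_embeddings, which maps the tokens scheduled for a chunked-prefill step onto each multimodal placeholder. The following worked sketch uses made-up numbers (a hypothetical 576-token image starting at prompt offset 100); it mirrors the arithmetic above but is not part of the diff.

# Hypothetical example values, not taken from the commit.
start_pos = 100            # pos_info.offset: placeholder begins at prompt token 100
num_encoder_tokens = 576   # pos_info.length: the image spans 576 placeholder tokens

# Suppose this step schedules prompt tokens [448, 704) for the request
# (query_start = computed_prefill_len, query_end = query_start + num_scheduled).
query_start, query_end = 448, 704

# Clip the scheduled window to the placeholder, in placeholder-local coordinates.
start_idx = max(query_start - start_pos, 0)                # 348
end_idx = min(query_end - start_pos, num_encoder_tokens)   # 576
assert (start_idx, end_idx) == (348, 576)

# Rows [0, 348) of this image's encoder output were consumed by earlier chunks
# and already live in the decoder's KV cache, so only encoder_output[348:576]
# is gathered into inputs_embeds for this step.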

vllm/v1/worker/gpu/mm/mrope_utils.py

Lines changed: 3 additions & 5 deletions
@@ -23,7 +23,7 @@ def __init__(
         # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
         # wasting a lot of CPU memory.
         self.prefill_mrope_positions = StagedWriteTensor(
-            (max_num_reqs, 3 * max_model_len),
+            (max_num_reqs * 3, max_model_len),
             dtype=torch.int32,
             device=device,
             uva_instead_of_gpu=True,
@@ -58,9 +58,7 @@ def init_prefill_mrope_positions(
         )
         for i in range(3):
             pos = prefill_mrope_positions[i].tolist()
-            self.prefill_mrope_positions.stage_write(
-                req_idx, i * self.max_model_len, pos
-            )
+            self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
         self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta

    def apply_staged_writes(self) -> None:
@@ -79,7 +77,7 @@ def prepare_mrope_positions(
            self.mrope_positions,
            self.mrope_positions.stride(0),
            self.prefill_mrope_positions.gpu,
-            self.prefill_mrope_positions.gpu.stride(0),
+            3 * self.max_model_len,
            self.max_model_len,
            self.prefill_mrope_delta.gpu,
            idx_mapping,
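Note on the change above: reshaping from (max_num_reqs, 3 * max_model_len) to (max_num_reqs * 3, max_model_len) keeps the same row-major memory layout; each of the three M-RoPE axes simply becomes its own row. That is why stage_write(3 * req_idx + i, 0, pos) replaces the column-offset form, and why the kernel is handed an explicit per-request stride of 3 * max_model_len instead of the tensor's stride(0). A quick self-contained check of that equivalence (illustrative sizes only, not from the diff):

max_model_len = 4096  # illustrative value

def old_offset(req_idx: int, axis: int, col: int) -> int:
    # shape (max_num_reqs, 3 * max_model_len): row req_idx, column axis * L + col
    return req_idx * (3 * max_model_len) + axis * max_model_len + col

def new_offset(req_idx: int, axis: int, col: int) -> int:
    # shape (max_num_reqs * 3, max_model_len): row 3 * req_idx + axis, column col
    return (3 * req_idx + axis) * max_model_len + col

assert all(
    old_offset(r, i, c) == new_offset(r, i, c)
    for r in range(4)
    for i in range(3)
    for c in (0, 1, max_model_len - 1)
)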
