Commit 0377802

sfeng33 authored
[Multimodal] Remove legacy multimodal fields in favor of MultiModalFeatureSpec (vllm-project#24548)
Signed-off-by: sfeng33 <[email protected]>
1 parent 72fc8aa commit 0377802

File tree: 13 files changed (+102, -116 lines)

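Note: the sketch below is an editorial summary of the migration this commit performs, not code from the diff. It approximates the shape of MultiModalFeatureSpec as used here (the real definitions live in vllm.multimodal.inputs) and shows how the removed legacy fields map onto it; the field types are assumptions inferred from usage in the hunks below.

# Minimal, assumed stand-ins for illustration only; the real classes are
# provided by vllm.multimodal.inputs.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PlaceholderRange:
    offset: int    # index of the first placeholder token for this mm input
    length: int    # number of placeholder tokens it occupies


@dataclass
class MultiModalFeatureSpec:
    data: Optional[Any]            # replaces request.mm_kwargs[i]
    modality: str                  # e.g. "image"
    identifier: str                # replaces request.mm_hashes[i]
    mm_position: PlaceholderRange  # replaces request.mm_positions[i]


# Old access pattern (three parallel lists):
#   mm_hash = request.mm_hashes[i]
#   offset = request.mm_positions[i].offset
# New access pattern (a single list of feature specs):
#   mm_hash = request.mm_features[i].identifier
#   offset = request.mm_features[i].mm_position.offset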
tests/v1/core/test_encoder_cache_manager.py
Lines changed: 11 additions & 1 deletion

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 
 
@@ -9,8 +10,17 @@ class MockRequest:
 
     def __init__(self, request_id, mm_hashes, token_counts):
         self.request_id = request_id
-        self.mm_hashes = mm_hashes
         self._token_counts = token_counts
+        self.mm_features = []
+        for i, mm_hash in enumerate(mm_hashes):
+            feature = MultiModalFeatureSpec(
+                data=None,
+                modality="image",
+                identifier=mm_hash,
+                mm_position=PlaceholderRange(offset=0,
+                                             length=self._token_counts[i]),
+            )
+            self.mm_features.append(feature)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
         return self._token_counts[input_id]

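As a quick, hypothetical usage of the updated MockRequest above (not part of the test file), the hashes and token counts are now reached through the feature specs:

# Hypothetical usage sketch; assumes the MockRequest defined in the diff above.
req = MockRequest(request_id="req-1",
                  mm_hashes=["hash-a", "hash-b"],
                  token_counts=[16, 32])

# The legacy parallel lists are gone; hashes come from the feature specs.
assert [f.identifier for f in req.mm_features] == ["hash-a", "hash-b"]
assert req.mm_features[1].mm_position.length == 32
assert req.get_num_encoder_tokens(1) == 32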
tests/v1/tpu/worker/test_tpu_model_runner.py
Lines changed: 1 addition & 3 deletions

@@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
-                mm_kwargs=[],
-                mm_hashes=[],
-                mm_positions=[],
+                mm_features=[],
                 sampling_params=SamplingParams(),
                 pooling_params=PoolingParams(),
                 block_ids=([0], ),  # block_ids should be tuple[list[int]]

tests/v1/worker/test_gpu_input_batch.py
Lines changed: 1 addition & 3 deletions

@@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         prompt_token_ids=prompt_token_ids,
         sampling_params=_create_sampling_params(),
         pooling_params=None,
-        mm_kwargs=[],
-        mm_positions=[],
-        mm_hashes=[],
+        mm_features=[],
         block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),

tests/v1/worker/test_gpu_model_runner.py
Lines changed: 1 addition & 3 deletions

@@ -118,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
-                mm_kwargs=[],
-                mm_hashes=[],
-                mm_positions=[],
+                mm_features=[],
                 sampling_params=SamplingParams(),
                 pooling_params=None,
                 block_ids=([0], ),

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
Lines changed: 22 additions & 19 deletions

@@ -300,23 +300,25 @@ def build_connector_meta(
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
             if new_req.req_id in self._requests_need_load:
-                meta.add_request(token_ids=new_req.prompt_token_ids,
-                                 block_ids=new_req.block_ids[0],
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=new_req.mm_hashes)
+                meta.add_request(
+                    token_ids=new_req.prompt_token_ids,
+                    block_ids=new_req.block_ids[0],
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in new_req.mm_features])
                 total_need_load += 1
             else:
                 # NOTE: here, we set the store and load being exclusive,
                 # but a single request can have both store and load.
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
                 if not self._found_match_for_request(new_req):
-                    meta.add_request(token_ids=new_req.prompt_token_ids,
-                                     block_ids=new_req.block_ids[0],
-                                     block_size=self._block_size,
-                                     is_store=True,
-                                     mm_hashes=new_req.mm_hashes)
+                    meta.add_request(
+                        token_ids=new_req.prompt_token_ids,
+                        block_ids=new_req.block_ids[0],
+                        block_size=self._block_size,
+                        is_store=True,
+                        mm_hashes=[f.identifier for f in new_req.mm_features])
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -341,11 +343,12 @@ def build_connector_meta(
                 # of the block_ids for the request.
                 block_ids = new_block_ids[0]
 
-                meta.add_request(token_ids=token_ids,
-                                 block_ids=block_ids,
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=request.mm_hashes)
+                meta.add_request(
+                    token_ids=token_ids,
+                    block_ids=block_ids,
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in request.mm_features])
                 total_need_load += 1
 
         assert total_need_load == len(self._requests_need_load)
@@ -364,10 +367,10 @@ def _found_match_for_request(
         """
         num_tokens_to_check = align_to_block_size(
             len(request.prompt_token_ids) - 1, self._block_size)
-        foldername = self._generate_foldername_debug(torch.tensor(
-            request.prompt_token_ids)[:num_tokens_to_check],
-                                                     request.mm_hashes,
-                                                     create_folder=False)
+        foldername = self._generate_foldername_debug(
+            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
+            [f.identifier for f in request.mm_features],
+            create_folder=False)
         return os.path.exists(foldername)
 
     def _generate_foldername_debug(

vllm/v1/core/encoder_cache_manager.py
Lines changed: 5 additions & 5 deletions

@@ -86,7 +86,7 @@ def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         Returns:
             True if the encoder output for this input is already cached
         """
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # Not cached at all
         if mm_hash not in self.cached:
             return False
@@ -167,7 +167,7 @@ def allocate(self, request: Request, input_id: int) -> None:
         This method assumes can_allocate() returned True for the same input.
         """
 
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         request_id = request.request_id
         if mm_hash not in self.cached:
             self.cached[mm_hash] = set()
@@ -193,8 +193,8 @@ def get_cached_input_ids(self, request: Request) -> set[int]:
         """
         return {
             input_id
-            for input_id in range(len(request.mm_hashes))
-            if request.mm_hashes[input_id] in self.cached
+            for input_id in range(len(request.mm_features))
+            if request.mm_features[input_id].identifier in self.cached
         }
 
     def free_encoder_input(self, request: Request, input_id: int) -> None:
@@ -208,7 +208,7 @@ def free_encoder_input(self, request: Request, input_id: int) -> None:
         `can_allocate`).
         """
         req_id = request.request_id
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # The mm_hash not in cache or the req_id set is empty
         if not self.cached.get(mm_hash, None):
             return

vllm/v1/core/kv_cache_utils.py
Lines changed: 16 additions & 20 deletions

@@ -418,9 +418,9 @@ def need_extra_keys(request: Request) -> bool:
     # Multimodal requests need to include the MM hash.
     # LoRA requests need to include the LoRA ID.
     # Request with provided cache salt need to include the salt.
-    return bool(request.mm_hashes) or (request.lora_request
-                                       is not None) or (request.cache_salt
-                                                        is not None)
+    return bool(request.mm_features) or (request.lora_request
+                                         is not None) or (request.cache_salt
+                                                          is not None)
 
 
 def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
@@ -442,40 +442,36 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
     """
     extra_keys: list[Any] = []
 
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if not mm_positions:
+    mm_features = request.mm_features
+    if not mm_features:
         return extra_keys, start_mm_idx
 
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match. This "
-            "is likely because you did not enable MM hashing. "
-            "Please set `mm_processor_cache_gb > 0`.")
-
-    # Note that we assume mm_positions is sorted by offset.
+    # Note that we assume mm_features are sorted by mm_position.offset.
     # We do not need to check all mm inputs if the start token index is out of
     # range. This usually happens in the late prefill phase and decoding phase.
-    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
+    last_pos = mm_features[-1].mm_position
+    if last_pos.offset + last_pos.length < start_token_idx:
        return extra_keys, start_mm_idx
 
     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
-        assert -start_mm_idx <= len(mm_positions)
-        start_mm_idx = len(mm_positions) + start_mm_idx
+        assert -start_mm_idx <= len(mm_features)
+        start_mm_idx = len(mm_features) + start_mm_idx
 
     curr_mm_idx = start_mm_idx
-    while mm_positions and curr_mm_idx < len(mm_positions):
-        assert mm_hashes[curr_mm_idx] is not None
-        offset = mm_positions[curr_mm_idx].offset
-        length = mm_positions[curr_mm_idx].length
+    while mm_features and curr_mm_idx < len(mm_features):
+        mm_feature = mm_features[curr_mm_idx]
+        assert mm_feature.identifier is not None
+        offset = mm_feature.mm_position.offset
+        length = mm_feature.mm_position.length
         if end_token_idx > offset:
             if start_token_idx > offset + length:
                 # This block has passed the current mm input.
                 curr_mm_idx += 1
                 continue
 
             # The block contains the current mm input.
-            extra_keys.append(mm_hashes[curr_mm_idx])
+            extra_keys.append(mm_feature.identifier)
 
             if end_token_idx >= offset + length:
                 # If this block contains the end of the current mm input,

vllm/v1/core/sched/output.py
Lines changed: 5 additions & 13 deletions

@@ -15,7 +15,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorMetadata)
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.request import Request
@@ -27,9 +27,7 @@ class NewRequestData:
 
     req_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: list[MultiModalKwargsItem]
-    mm_hashes: list[str]
-    mm_positions: list[PlaceholderRange]
+    mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
@@ -45,9 +43,7 @@ def from_request(
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
-            mm_kwargs=request.mm_kwargs,
-            mm_hashes=request.mm_hashes,
-            mm_positions=request.mm_positions,
+            mm_features=request.mm_features,
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
             block_ids=block_ids,
@@ -59,9 +55,7 @@ def __repr__(self):
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
-                f"mm_kwargs={self.mm_kwargs},"
-                f"mm_hashes={self.mm_hashes},"
-                f"mm_positions={self.mm_positions},"
+                f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
@@ -73,9 +67,7 @@ def anon_repr(self):
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids_len={len(self.prompt_token_ids)},"
-                f"mm_kwargs={self.mm_kwargs},"
-                f"mm_hashes={self.mm_hashes},"
-                f"mm_positions={self.mm_positions},"
+                f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"

vllm/v1/core/sched/scheduler.py
Lines changed: 11 additions & 11 deletions

@@ -736,18 +736,18 @@ def _try_schedule_encoder_inputs(
         if num_new_tokens == 0 or not request.has_encoder_inputs:
             return [], num_new_tokens, encoder_compute_budget
         encoder_inputs_to_schedule: list[int] = []
-        mm_positions = request.mm_positions
-        assert mm_positions is not None
-        assert len(mm_positions) > 0
+        mm_features = request.mm_features
+        assert mm_features is not None
+        assert len(mm_features) > 0
 
         # NOTE: since scheduler operates on the request level (possibly with
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
         mm_hashes_to_schedule = set()
         num_tokens_to_schedule = 0
-        for i, pos_info in enumerate(mm_positions):
-            start_pos = pos_info.offset
-            num_encoder_tokens = pos_info.length
+        for i, mm_feature in enumerate(mm_features):
+            start_pos = mm_feature.mm_position.offset
+            num_encoder_tokens = mm_feature.mm_position.length
 
             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -778,7 +778,7 @@
             if not self.is_encoder_decoder:
                 # We are not using the encoder cache for encoder-decoder models,
                 # yet.
-                if request.mm_hashes[i] in mm_hashes_to_schedule:
+                if request.mm_features[i].identifier in mm_hashes_to_schedule:
                     # The same encoder input has already been scheduled in the
                     # current step.
                     continue
@@ -820,7 +820,7 @@
 
             num_tokens_to_schedule += num_encoder_tokens
             encoder_compute_budget -= num_encoder_tokens
-            mm_hashes_to_schedule.add(request.mm_hashes[i])
+            mm_hashes_to_schedule.add(request.mm_features[i].identifier)
             encoder_inputs_to_schedule.append(i)
 
         return (
@@ -1048,9 +1048,9 @@ def _free_encoder_inputs(self, request: Request) -> None:
         # Here, we use list(set) to avoid modifying the set while iterating
        # over it.
         for input_id in list(cached_encoder_input_ids):
-            mm_positions = request.mm_positions[input_id]
-            start_pos = mm_positions.offset
-            num_tokens = mm_positions.length
+            mm_feature = request.mm_features[input_id]
+            start_pos = mm_feature.mm_position.offset
+            num_tokens = mm_feature.mm_position.length
             if self.is_encoder_decoder and request.num_computed_tokens > 0:
                 # With Whisper, as soon as we've generated a single token,
                 # we know we're done with the encoder input. Cross Attention

vllm/v1/request.py
Lines changed: 2 additions & 7 deletions

@@ -91,11 +91,6 @@ def __init__(
         self.mm_features = mm_features or []
         self.num_encoder_inputs = len(self.mm_features)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
-        # TODO(sfeng33): Remove these legacy fields after clearing out all
-        # references in scheduler and model runner
-        self.mm_positions = [f.mm_position for f in self.mm_features]
-        self.mm_kwargs = [f.data for f in self.mm_features]
-        self.mm_hashes = [f.identifier for f in self.mm_features]
 
         # Read-only views
         # Prevent directly appending to these lists since
@@ -180,8 +175,8 @@ def get_finished_reason(self) -> Union[FinishReason, None]:
         return RequestStatus.get_finished_reason(self.status)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
-        assert input_id < len(self.mm_positions)
-        num_tokens = self.mm_positions[input_id].length
+        assert input_id < len(self.mm_features)
+        num_tokens = self.mm_features[input_id].mm_position.length
         return num_tokens
 
     def record_event(

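For downstream code that still expects the flattened lists, they can be rebuilt on demand from mm_features; the one-liners below simply mirror the fields removed from Request in this commit and are an illustrative sketch rather than a recommended API.

# Sketch: rebuilding the removed legacy views from any vllm.v1.request.Request.
mm_positions = [f.mm_position for f in request.mm_features]
mm_kwargs = [f.data for f in request.mm_features]
mm_hashes = [f.identifier for f in request.mm_features]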