
Commit 69f4635

[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (vllm-project#23779)
Signed-off-by: sfeng33 <[email protected]>
1 parent d9e00db commit 69f4635

16 files changed: +143 additions, -146 deletions
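
For orientation: every hunk below replaces the old trio of per-request multimodal fields (multi_modal_kwargs / multi_modal_hashes / multi_modal_placeholders on Request, or mm_kwargs / mm_hashes / mm_placeholders on EngineCoreRequest) with a single mm_features list. A minimal sketch of the consolidated container, inferred only from the keyword arguments used in the test call sites below; the actual class in vllm.multimodal.inputs may use different types or carry extra fields:

    # Sketch inferred from the call sites in this commit -- not the
    # actual definition in vllm.multimodal.inputs.
    from dataclasses import dataclass

    @dataclass
    class MultiModalFeatureSpec:
        data: "MultiModalKwargsItem"     # processed kwargs for one mm item
        mm_position: "PlaceholderRange"  # placeholder span in the prompt
        identifier: str                  # unique hash; keys the encoder cache
        modality: str                    # e.g. "image"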

tests/tokenization/test_detokenize.py

Lines changed: 0 additions & 2 deletions
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
     request = EngineCoreRequest("",
                                 prompt_token_ids,
                                 None,
-                                None,
-                                None,
                                 params,
                                 None,
                                 None,
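
Note: dropping two positional Nones here is consistent with the consolidation described above; where EngineCoreRequest previously took three mm_* arguments, it now takes one mm_features, so a positional call sheds two slots (compare the keyword-style rewrite in tests/v1/engine/test_fast_incdec_prefix_err.py below).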

tests/v1/core/test_kv_cache_utils.py

Lines changed: 13 additions & 9 deletions
@@ -7,7 +7,8 @@
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager

@@ -37,17 +38,20 @@ def make_request(
     mm_hashes: Optional[list[str]] = None,
     cache_salt: Optional[str] = None,
 ):
-    if mm_positions is None:
-        mm_kwargs = None
-    else:
-        mm_item = MultiModalKwargsItem.dummy("dummy_m")
-        mm_kwargs = [mm_item] * len(mm_positions)
+    mm_features = []
+    if mm_positions is not None:
+        for j, position in enumerate(mm_positions):
+            identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
+            mm_feature = MultiModalFeatureSpec(
+                data=MultiModalKwargsItem.dummy("dummy_m"),
+                mm_position=position,
+                identifier=identifier,
+                modality="image")
+            mm_features.append(mm_feature)
 
     return Request(request_id=request_id,
                    prompt_token_ids=prompt_token_ids,
-                   multi_modal_kwargs=mm_kwargs,
-                   multi_modal_hashes=mm_hashes,
-                   multi_modal_placeholders=mm_positions,
+                   mm_features=mm_features if mm_features else None,
                    sampling_params=SamplingParams(max_tokens=17),
                    pooling_params=None,
                    eos_token_id=100,
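
The same old-triple-to-feature-list rewrite recurs in several test helpers below. As a hypothetical refactor (not part of this commit), it could be factored out once; everything except the two imported names is invented for illustration:

    # Hypothetical helper, not in this commit: converts the old parallel
    # inputs (positions plus optional hashes) into the consolidated list.
    from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                        MultiModalKwargsItem)

    def to_mm_features(mm_positions, mm_hashes=None, modality="image"):
        if not mm_positions:
            return None
        return [
            MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=position,
                identifier=mm_hashes[j] if mm_hashes else f"hash_{j}",
                modality=modality,
            ) for j, position in enumerate(mm_positions)
        ]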

tests/v1/core/test_prefix_caching.py

Lines changed: 13 additions & 9 deletions
@@ -9,7 +9,8 @@
 import torch
 
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool

@@ -32,17 +33,20 @@ def make_request(
     prompt_logprobs: Optional[int] = None,
     cache_salt: Optional[str] = None,
 ):
-    if mm_positions is None:
-        mm_kwargs = None
-    else:
-        mm_item = MultiModalKwargsItem.dummy("dummy_m")
-        mm_kwargs = [mm_item] * len(mm_positions)
+    mm_features = []
+    if mm_positions is not None:
+        for j, position in enumerate(mm_positions):
+            identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
+            mm_feature = MultiModalFeatureSpec(
+                data=MultiModalKwargsItem.dummy("dummy_m"),
+                mm_position=position,
+                identifier=identifier,
+                modality="image")
+            mm_features.append(mm_feature)
 
     return Request(request_id=request_id,
                    prompt_token_ids=prompt_token_ids,
-                   multi_modal_kwargs=mm_kwargs,
-                   multi_modal_hashes=mm_hashes,
-                   multi_modal_placeholders=mm_positions,
+                   mm_features=mm_features if mm_features else None,
                    sampling_params=SamplingParams(
                        max_tokens=17, prompt_logprobs=prompt_logprobs),
                    pooling_params=None,

tests/v1/core/test_scheduler.py

Lines changed: 14 additions & 12 deletions
@@ -8,7 +8,8 @@
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler

@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
         prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
+        mm_features = []
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_item = MultiModalKwargsItem.dummy("dummy_m")
-            mm_kwargs = [mm_item] * len(mm_position)
-        else:
-            mm_position = None
-            mm_kwargs = None
+            for j, position in enumerate(mm_position):
+                identifier = f"hash{i}_{j}"
+                mm_feature = MultiModalFeatureSpec(
+                    data=MultiModalKwargsItem.dummy("dummy_m"),
+                    mm_position=position,
+                    identifier=identifier,
+                    modality="image")
+                mm_features.append(mm_feature)
+
         request = Request(
             request_id=f"{i + starting_idx}",
             prompt_token_ids=[i + starting_idx] * num_tokens,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_kwargs=mm_kwargs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
+            mm_features=mm_features if mm_features else None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=arrival_times[i],
            priority=priorities[i],

@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
-        multi_modal_kwargs=None,
-        multi_modal_hashes=None,
-        multi_modal_placeholders=None,
+        mm_features=None,
         sampling_params=sampling_params,
         pooling_params=None,
         eos_token_id=EOS_TOKEN_ID,

tests/v1/core/utils.py

Lines changed: 15 additions & 15 deletions
@@ -6,7 +6,8 @@
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)

@@ -139,29 +140,28 @@ def create_requests(
         prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
+        mm_features = []
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_item = MultiModalKwargsItem.dummy("dummy_m")
-            mm_kwargs = [mm_item] * len(mm_position)
-            # Dummy hash for each mm item should be unique
-            # since encoder cache tracks entries by hash
-            mm_hashes = [
-                "hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
-            ]
-        else:
-            mm_position = None
-            mm_kwargs = None
-            mm_hashes = None
+            for j, position in enumerate(mm_position):
+                # Dummy hash for each mm item should be unique
+                # since encoder cache tracks entries by hash
+                identifier = f"hash{i}_{j}"
+                mm_feature = MultiModalFeatureSpec(
+                    data=MultiModalKwargsItem.dummy("dummy_m"),
+                    mm_position=position,
+                    identifier=identifier,
+                    modality="image")
+                mm_features.append(mm_feature)
+
         prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
                             num_tokens)
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_kwargs=mm_kwargs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=mm_hashes,
+            mm_features=mm_features if mm_features else None,
             eos_token_id=EOS_TOKEN_ID,
             block_hasher=block_hasher,
         )
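
The comment kept inside the loop states the key invariant: the encoder cache tracks entries by hash, so each (request, item) pair needs a distinct identifier. A quick sanity sketch of the f"hash{i}_{j}" scheme used here:

    # Two requests with two mm items each yield four distinct cache keys,
    # so dummy items are never conflated across requests or positions.
    ids = [f"hash{i}_{j}" for i in range(2) for j in range(2)]
    assert ids == ["hash0_0", "hash0_1", "hash1_0", "hash1_1"]
    assert len(set(ids)) == len(ids)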

tests/v1/engine/test_engine_core.py

Lines changed: 1 addition & 3 deletions
@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        mm_features=None,
         sampling_params=SamplingParams(),
         pooling_params=None,
         eos_token_id=None,

tests/v1/engine/test_engine_core_client.py

Lines changed: 1 addition & 3 deletions
@@ -52,9 +52,7 @@ def make_request(
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=prompt_tokens_ids,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        mm_features=None,
         sampling_params=params,
         pooling_params=None,
         eos_token_id=None,

tests/v1/engine/test_fast_incdec_prefix_err.py

Lines changed: 8 additions & 10 deletions
@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
     prompt_token_ids = [107, 4606, 236787, 107]
     params = SamplingParams(skip_special_tokens=True)
     request = EngineCoreRequest(
-        "test",
-        prompt_token_ids,
-        None,
-        None,
-        None,
-        params,
-        None,
-        None,
-        0.0,
-        None,
+        request_id="test",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
         cache_salt=None,
         data_parallel_rank=None,
     )
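
Besides swapping three mm_* arguments for mm_features=None, this hunk converts the call from positional to keyword form. That is what makes a field consolidation like this one safe: a stale positional call silently rebinds every later argument, while a stale keyword call fails loudly. In miniature (argument lists abbreviated purely for illustration):

    # Positional: removing a field shifts later arguments without warning.
    EngineCoreRequest("test", prompt_token_ids, None, params)      # fragile

    # Keyword: each value is bound by name, so a signature change surfaces
    # immediately as a TypeError at the call site.
    EngineCoreRequest(request_id="test",
                      prompt_token_ids=prompt_token_ids,
                      mm_features=None,
                      sampling_params=params)                      # robust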

tests/v1/engine/test_output_processor.py

Lines changed: 10 additions & 20 deletions
@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
                           prompt_token_ids=prompt_tokens,
-                          arrival_time=0,
-                          mm_kwargs=None,
-                          mm_hashes=None,
-                          mm_placeholders=None,
+                          mm_features=None,
                           eos_token_id=None,
+                          arrival_time=0,
                           lora_request=None,
                           cache_salt=None,
                           data_parallel_rank=None,

@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
                           prompt_token_ids=prompt_tokens,
-                          arrival_time=0,
-                          mm_kwargs=None,
-                          mm_hashes=None,
-                          mm_placeholders=None,
+                          mm_features=None,
                           eos_token_id=None,
+                          arrival_time=0,
                           lora_request=None,
                           cache_salt=None,
                           data_parallel_rank=None,

@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
     request = EngineCoreRequest(
         request_id=request_id,
         prompt_token_ids=prompt_tokens,
-        arrival_time=0,
-        mm_kwargs=None,
-        mm_hashes=None,
-        mm_placeholders=None,
+        mm_features=None,
         eos_token_id=eos_token_id,
+        arrival_time=0,
         lora_request=None,
         cache_salt=None,
         data_parallel_rank=None,

@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
         EngineCoreRequest(
             request_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
-            arrival_time=0,
-            mm_kwargs=None,
-            mm_hashes=None,
-            mm_placeholders=None,
+            mm_features=None,
             eos_token_id=None,
+            arrival_time=0,
             lora_request=None,
             cache_salt=None,
             data_parallel_rank=None,

@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
         EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
-            arrival_time=0,
-            mm_kwargs=None,
-            mm_hashes=None,
-            mm_placeholders=None,
+            mm_features=None,
             eos_token_id=None,
+            arrival_time=0,
             lora_request=None,
             cache_salt=None,
             data_parallel_rank=None,
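
All five hunks in this file apply the identical substitution (and move arrival_time after eos_token_id, apparently to follow the new field order). A hypothetical test factory, not part of this commit, would let future signature changes touch one place; it mirrors only the keyword arguments visible in the hunks above, with the remaining fields passed through **kw since the full call sites are not shown in this diff:

    # Hypothetical factory for these tests (not in this commit).
    def make_core_request(request_id, prompt_token_ids,
                          eos_token_id=None, **kw):
        return EngineCoreRequest(request_id=request_id,
                                 prompt_token_ids=prompt_token_ids,
                                 mm_features=None,
                                 eos_token_id=eos_token_id,
                                 arrival_time=0,
                                 lora_request=None,
                                 cache_salt=None,
                                 data_parallel_rank=None,
                                 **kw)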

tests/v1/kv_connector/unit/utils.py

Lines changed: 1 addition & 3 deletions
@@ -162,9 +162,7 @@ def create_request(request_id: int,
         prompt_token_ids=prompt_token_ids,
         sampling_params=sampling_params,
         pooling_params=None,
-        multi_modal_kwargs=None,
-        multi_modal_placeholders=None,
-        multi_modal_hashes=None,
+        mm_features=None,
         eos_token_id=EOS_TOKEN_ID,
         block_hasher=get_request_block_hasher(block_size, hash_fn),
     )
