
Commit 0ad9951

[Input] Remove unused prompt field (vllm-project#26097)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 8c91171 commit 0ad9951

File tree

15 files changed: +67 additions, -101 deletions

tests/models/multimodal/processing/test_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -37,4 +37,5 @@ def test_multimodal_processor(model_id):
         hf_processor_mm_kwargs={},
     )
 
-    assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"]
+    assert (str_processed_inputs["prompt_token_ids"]
+            == ids_processed_inputs["prompt_token_ids"])

vllm/engine/protocol.py

Lines changed: 8 additions & 3 deletions
@@ -94,10 +94,15 @@ async def beam_search(
         # this happens again in generation, so the double expansion causes
         # a mismatch.
         # TODO - would be ideal to handle this more gracefully.
-        prompt_token_ids = prompt.get("prompt_token_ids")
-        multi_modal_data = prompt.get("multi_modal_data")
+        if isinstance(prompt, str):
+            prompt_text = prompt
+            prompt_token_ids = []
+            multi_modal_data = None
+        else:
+            prompt_text = prompt.get("prompt")
+            prompt_token_ids = prompt.get("prompt_token_ids", [])
+            multi_modal_data = prompt.get("multi_modal_data")
 
-        prompt_text = processed_inputs.get("prompt")
         mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
 
         tokenized_length = len(prompt_token_ids)
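
For reference, a minimal standalone sketch of the branching the new code performs when a prompt may be either a plain string or a dict-style prompt (the helper name below is hypothetical and not part of the commit; only the field names mirror the diff):

from typing import Any, Optional, Union

def split_prompt_fields(
    prompt: Union[str, dict[str, Any]],
) -> tuple[Optional[str], list[int], Optional[Any]]:
    # A plain string carries no token IDs or multi-modal data, so default
    # them; a dict-style prompt (e.g. TokensPrompt) may supply all three.
    if isinstance(prompt, str):
        return prompt, [], None
    return (
        prompt.get("prompt"),
        prompt.get("prompt_token_ids", []),
        prompt.get("multi_modal_data"),
    )

# split_prompt_fields("Hello")                      -> ("Hello", [], None)
# split_prompt_fields({"prompt_token_ids": [1, 2]}) -> (None, [1, 2], None)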

vllm/inputs/data.py

Lines changed: 0 additions & 8 deletions
@@ -205,11 +205,6 @@ class TokenInputs(TypedDict):
     prompt_token_ids: list[int]
     """The token IDs of the prompt."""
 
-    prompt: NotRequired[str]
-    """
-    The original prompt text corresponding to the token IDs, if available.
-    """
-
     cache_salt: NotRequired[str]
     """
     Optional cache salt to be used for prefix caching.
@@ -218,15 +213,12 @@ class TokenInputs(TypedDict):
 
 def token_inputs(
     prompt_token_ids: list[int],
-    prompt: Optional[str] = None,
     cache_salt: Optional[str] = None,
 ) -> TokenInputs:
     """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
     values."""
     inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
 
-    if prompt is not None:
-        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt
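
With the `prompt` parameter gone, `TokenInputs` is built from token IDs alone; a small usage sketch (the token IDs are placeholders):

from vllm.inputs.data import token_inputs

# Only token IDs are stored; the original prompt text is no longer kept here.
inputs = token_inputs([101, 2023, 2003, 1037, 3231, 102])

# cache_salt stays optional and is only recorded when provided.
salted = token_inputs([101, 2023, 102], cache_salt="request-42")

assert "prompt" not in inputs  # the field this commit removes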

vllm/inputs/preprocess.py

Lines changed: 17 additions & 20 deletions
@@ -16,9 +16,10 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
-                   EncoderDecoderInputs, ProcessorInputs, PromptType,
-                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
-                   TokensPrompt, embeds_inputs, token_inputs)
+                   EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
+                   ProcessorInputs, PromptType, SingletonInputs,
+                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
+                   embeds_inputs, token_inputs)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
 
 logger = init_logger(__name__)
@@ -322,7 +323,7 @@ def _process_tokens(
                 mm_uuids=mm_uuids,
             )
         else:
-            inputs = token_inputs(prompt_token_ids=prompt_token_ids)
+            inputs = token_inputs(prompt_token_ids)
 
         if cache_salt := parsed_content.get("cache_salt"):
             inputs["cache_salt"] = cache_salt
@@ -352,10 +353,7 @@ def _process_text(
                 prompt_text,
                 tokenization_kwargs=tokenization_kwargs,
             )
-            inputs = token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
+            inputs = token_inputs(prompt_token_ids)
 
         if cache_salt := parsed_content.get("cache_salt"):
             inputs["cache_salt"] = cache_salt
@@ -473,22 +471,17 @@ def _split_enc_dec_mm_inputs(
         decoder_inputs: SingletonInputs
 
         if inputs["type"] == "multimodal":  # Multimodal data inputs
-            if not ("encoder_prompt" in inputs
-                    and "encoder_prompt_token_ids" in inputs):
+            if "encoder_prompt_token_ids" not in inputs:
                 raise RuntimeError("You should register an encoder-decoder "
                                    "multi-modal processor for encoder-decoder "
                                    "models.")
             inputs = cast(MultiModalEncDecInputs, inputs)
 
-            encoder_inputs = token_inputs(
-                prompt=inputs["encoder_prompt"],
-                prompt_token_ids=inputs["encoder_prompt_token_ids"],
-            )
+            encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"])
 
             decoder_prompt_inputs = decoder_inputs_to_override or inputs
             decoder_inputs = MultiModalInputs(
                 type="multimodal",
-                prompt=decoder_prompt_inputs.get("prompt", ""),
                 prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
                 mm_kwargs=inputs["mm_kwargs"],
                 mm_hashes=inputs["mm_hashes"],
@@ -498,7 +491,7 @@ def _split_enc_dec_mm_inputs(
                 decoder_inputs["cache_salt"] = cache_salt
 
         elif inputs["type"] == "token":  # Text-only inputs
-            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
+            encoder_inputs = token_inputs(prompt_token_ids=[])
             decoder_inputs = decoder_inputs_to_override or inputs
         else:
             assert_never(inputs)  # type: ignore[arg-type]
@@ -549,12 +542,14 @@ def _process_encoder_decoder_prompt(
         decoder_inputs: Optional[SingletonInputs]
 
         if is_explicit_encoder_decoder_prompt(prompt):
+            # `cast` is needed for mypy, but not pyright
+            prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt)
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"],
+                prompt_["encoder_prompt"],
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
-            if (decoder_input := prompt["decoder_prompt"]) is None:
+            if (decoder_input := prompt_["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
                 decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
@@ -565,8 +560,9 @@ def _process_encoder_decoder_prompt(
                 self._split_enc_dec_mm_inputs(encoder_inputs,
                                               decoder_inputs))
         else:
+            # `cast` is needed for mypy, but not pyright
             inputs = self._prompt_to_llm_inputs(
-                prompt,
+                cast(SingletonPrompt, prompt),
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
@@ -641,8 +637,9 @@ def preprocess(
                 "to decoder-only models")
 
         # Decoder-only operation
+        # `cast` is needed for mypy, but not pyright
        return self._process_decoder_only_prompt(
-            prompt,
+            cast(SingletonPrompt, prompt),
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
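
The repeated "`cast` is needed for mypy, but not pyright" comments refer to the standard pattern of narrowing a union with `typing.cast`; a simplified illustration with stand-in types (not vLLM's real definitions):

from typing import Union, cast

# Stand-ins for the prompt types imported from .data; the real ones are
# TypedDicts with more fields.
TokensPrompt = dict
SingletonPrompt = Union[str, TokensPrompt]

def to_singleton(prompt: object) -> SingletonPrompt:
    # cast() has no runtime effect; it only tells mypy which union member
    # the value is, mirroring the cast(SingletonPrompt, prompt) calls above.
    return cast(SingletonPrompt, prompt)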

vllm/model_executor/models/llava.py

Lines changed: 1 addition & 2 deletions
@@ -778,7 +778,7 @@ def get_replacement_mantis(item_idx: int):
             )
         ], mm_item_counts)
 
-        prompt_ids, prompt, _ = self._apply_prompt_updates(
+        prompt_ids, _ = self._apply_prompt_updates(
             result["prompt_token_ids"],
             mantis_mm_repls,
         )
@@ -798,7 +798,6 @@ def get_replacement_mantis(item_idx: int):
 
         return MultiModalInputs(
             type="multimodal",
-            prompt=prompt,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,

vllm/model_executor/models/paligemma.py

Lines changed: 0 additions & 1 deletion
@@ -219,7 +219,6 @@ def apply(
         if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
             prompt_token_ids.append(newline_token_id)
         mm_inputs["prompt_token_ids"] = prompt_token_ids
-        mm_inputs["prompt"] += newline_prompt
 
         return mm_inputs

vllm/model_executor/models/phi3v.py

Lines changed: 5 additions & 5 deletions
@@ -461,7 +461,7 @@ def _apply_prompt_updates(
         self,
         token_ids: list[int],
         mm_prompt_updates: MultiModalPromptUpdates,
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
         # align to hf behavior when there are images
         if len(mm_prompt_updates):
             tokenizer = self.info.get_tokenizer()
@@ -496,14 +496,14 @@ def _apply_prompt_updates(
                 for ele in sublist for e in ele
             ]
 
-        token_ids, text, placeholders = super()._apply_prompt_updates(
+        token_ids, placeholders = super()._apply_prompt_updates(
             token_ids=token_ids,
             mm_prompt_updates=mm_prompt_updates,
         )
 
         # Keep the behavior in line with HF processor
-        if text.startswith("<s> <|image|>"):
-            text = text.replace("<s> <|image|>", "<s><|image|>", 1)
+        if token_ids[:2] == tokenizer.encode("<s> <|image|>",
+                                             add_special_tokens=False):
             token_ids = [token_ids[0], *token_ids[2:]]
             placeholders = {
                 modality: [
@@ -518,7 +518,7 @@ def _apply_prompt_updates(
                 for modality, ps in placeholders.items()
             }
 
-        return token_ids, text, placeholders
+        return token_ids, placeholders
 
 
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
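
The HF-alignment fix now compares a token-ID prefix instead of decoded text; a self-contained sketch of the same idea (the tokenizer argument is any object with an `encode` method, and the prompt strings follow the diff above):

def drop_space_before_image(token_ids: list[int], tokenizer) -> list[int]:
    # Encode the problematic "<s> <|image|>" prefix and compare it against
    # the first two tokens, as the updated phi3v code does.
    bad_prefix = tokenizer.encode("<s> <|image|>", add_special_tokens=False)
    if token_ids[:2] == bad_prefix:
        # Keep "<s>" and drop the stray space token.
        return [token_ids[0], *token_ids[2:]]
    return token_ids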

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 8 additions & 16 deletions
@@ -63,7 +63,7 @@
                                     PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
+from vllm.transformers_utils.tokenizer import encode_tokens
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -316,7 +316,7 @@ def _maybe_apply_prompt_updates(
         mm_kwargs: MultiModalKwargsItems,
         mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
@@ -341,28 +341,20 @@ def _maybe_apply_prompt_updates(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video)
-
-            tokenizer = self.info.get_tokenizer()
-            prompt = decode_tokens(tokenizer, prompt_ids)
+                use_audio_in_video=use_audio_in_video,
+            )
         else:
-            (
-                prompt_ids,
-                prompt,
-                mm_placeholders,
-            ) = self._apply_prompt_updates(
+            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                 prompt_ids,
                 mm_prompt_updates,
             )
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video)
-
-            tokenizer = self.info.get_tokenizer()
-            prompt = decode_tokens(tokenizer, prompt_ids)
+                use_audio_in_video=use_audio_in_video,
+            )
 
-        return prompt_ids, prompt, mm_placeholders
+        return prompt_ids, mm_placeholders
 
     def _get_prompt_updates(
         self,
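
Since `_maybe_apply_prompt_updates` no longer returns the decoded prompt text, a caller that still wants it would presumably decode on demand; a hedged sketch using the same `decode_tokens` helper the removed code called (whether any caller actually needs this is not shown in the commit):

from vllm.transformers_utils.tokenizer import decode_tokens

def prompt_text_for(tokenizer, prompt_ids: list[int]) -> str:
    # The removed code decoded eagerly after every update; decoding lazily
    # recovers the same string only when it is actually needed.
    return decode_tokens(tokenizer, prompt_ids)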

vllm/model_executor/models/terratorch.py

Lines changed: 0 additions & 1 deletion
@@ -190,7 +190,6 @@ def apply(
 
         return MultiModalInputs(
             type="multimodal",
-            prompt=prompt,
             prompt_token_ids=[1],
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,

vllm/model_executor/models/transformers.py

Lines changed: 0 additions & 1 deletion
@@ -453,7 +453,6 @@ def apply(
 
         return MultiModalInputs(
             type="multimodal",
-            prompt=prompt,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
