
Commit 324c755

[feat] Enable mm caching for transformers backend (vllm-project#21358)

zucchini-nlp authored and paulpak58 committed

Signed-off-by: raushan <[email protected]>
Signed-off-by: Paul Pak <[email protected]>

1 parent abc3291 · commit 324c755

File tree: 4 files changed, +7 -18 lines

docs/models/supported_models.md
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models]
 
 ### Transformers
 
-vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs, and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases.
+vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
 
 To check if the modeling backend is Transformers, you can simply do this:
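
To illustrate what this documentation change means in practice, here is a minimal sketch of multimodal offline inference through the Transformers backend; the checkpoint name, image path, and prompt template are illustrative assumptions, and the point is simply that `disable_mm_preprocessor_cache` no longer needs to be set.

from PIL import Image

from vllm import LLM, SamplingParams

# Force the Transformers modeling backend; the mm preprocessor cache and prefix
# caching are left at their defaults now that mm hashing is supported.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",  # assumed example VLM checkpoint
    model_impl="transformers",
)

image = Image.open("example.jpg")  # any local RGB image (path is illustrative)

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nDescribe this image. ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)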

tests/models/multimodal/generation/test_common.py
Lines changed: 0 additions & 8 deletions

@@ -186,8 +186,6 @@
         image_size_factors=[(0.25, 0.5, 1.0)],
         vllm_runner_kwargs={
             "model_impl": "transformers",
-            "disable_mm_preprocessor_cache": True,
-            "enable_prefix_caching": False,
         },
         marks=[pytest.mark.core_model],
     ),
@@ -205,8 +203,6 @@
     #     image_size_factors=[(0.25, 0.5, 1.0)],
     #     vllm_runner_kwargs={
     #         "model_impl": "transformers",
-    #         "disable_mm_preprocessor_cache": True,
-    #         "enable_prefix_caching": False,
     #     },
     #     marks=[pytest.mark.core_model],
     # ),
@@ -223,8 +219,6 @@
         image_size_factors=[(0.25, 0.2, 0.15)],
         vllm_runner_kwargs={
             "model_impl": "transformers",
-            "disable_mm_preprocessor_cache": True,
-            "enable_prefix_caching": False,
         },
         marks=[large_gpu_mark(min_gb=32)],
     ),
@@ -239,8 +233,6 @@
         image_size_factors=[(0.25, 0.5, 1.0)],
         vllm_runner_kwargs={
             "model_impl": "auto",
-            "disable_mm_preprocessor_cache": True,
-            "enable_prefix_caching": False,
         },
         auto_cls=AutoModelForImageTextToText,
         marks=[pytest.mark.core_model],
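
For context, a hedged sketch of what the trimmed settings reduce to in these tests: only the backend selection remains, so the multi-modal preprocessor cache and prefix caching fall back to their defaults. The snippet mirrors the keyword names in the test file but is itself illustrative.

vllm_runner_kwargs = {
    "model_impl": "transformers",
}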

vllm/model_executor/models/transformers.py
Lines changed: 3 additions & 6 deletions

@@ -315,11 +315,6 @@ def apply(
         Apply HF Processor on prompt text and multi-modal data together,
         outputting token IDs and processed tensors.
         """
-        if return_mm_hashes:
-            raise ValueError(
-                "TransformersForMultimodalLM doesn't support mm hashing yet! "
-                "Probably you didn't set `disable_mm_preprocessor_cache=True`")
-
         if tokenization_kwargs is None:
             tokenization_kwargs = {}
 
@@ -375,12 +370,14 @@ def apply(
                                            num_image_patches),
         )
 
+        mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
+                                        tokenization_kwargs)
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
-            mm_hashes=None,
+            mm_hashes=mm_hashes,
             mm_placeholders=mm_placeholders,
         )
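
The new call to `_hash_mm_items` is what makes `mm_hashes` available downstream. Its internals are not part of this diff; the sketch below is only an assumption-level illustration of the general idea, namely deriving a stable content hash per multi-modal item so identical images map to the same cache entries.

import hashlib

from PIL import Image


def hash_image_item(image: Image.Image) -> str:
    # Fold mode, size, and raw pixel bytes into one digest so byte-identical
    # images produce identical keys across requests (illustrative only).
    digest = hashlib.sha256()
    digest.update(image.mode.encode())
    digest.update(repr(image.size).encode())
    digest.update(image.tobytes())
    return digest.hexdigest()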

vllm/v1/core/kv_cache_utils.py
Lines changed: 3 additions & 3 deletions

@@ -406,9 +406,9 @@ def need_extra_keys(request: Request) -> bool:
     # Multimodal requests need to include the MM hash.
     # LoRA requests need to include the LoRA ID.
     # Request with provided cache salt need to include the salt.
-    return bool(request.mm_positions) or (request.lora_request
-                                          is not None) or (request.cache_salt
-                                                           is not None)
+    return bool(request.mm_hashes) or (request.lora_request
+                                       is not None) or (request.cache_salt
+                                                        is not None)
 
 
 def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
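
Conceptually, `need_extra_keys` now keys off `request.mm_hashes` because those per-item hashes (like LoRA IDs and cache salts) are folded into each KV-cache block hash, so requests share cached blocks only when both their token IDs and their images match. The sketch below is an assumption-level illustration, not vLLM's actual block-hashing code.

import hashlib
from typing import Iterable, Sequence


def compute_block_hash(parent_hash: str, token_ids: Sequence[int],
                       extra_keys: Iterable[str] = ()) -> str:
    # Chain the parent block hash with this block's tokens and any extra keys
    # (e.g. mm hashes, LoRA ID, cache salt) into one digest.
    digest = hashlib.sha256()
    digest.update(parent_hash.encode())
    digest.update(",".join(map(str, token_ids)).encode())
    for key in extra_keys:
        digest.update(str(key).encode())
    return digest.hexdigest()


# Same tokens but different image hashes yield different block keys.
assert compute_block_hash("root", [1, 2, 3], ["img-a"]) != \
       compute_block_hash("root", [1, 2, 3], ["img-b"])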
