
Commit 39b643d

[Model] Use merge_by_field_config for MM models (G) (vllm-project#26117)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 711f485 commit 39b643d

File tree (5 files changed: +56 −108 lines)

  vllm/model_executor/models/gemma3_mm.py
  vllm/model_executor/models/gemma3n_mm.py
  vllm/model_executor/models/glm4_1v.py
  vllm/model_executor/models/glm4v.py
  vllm/model_executor/models/granite_speech.py
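
The same refactor is applied to each model below: the class opts in with `merge_by_field_config = True`, so multimodal kwargs should arrive already merged per `MultiModalFieldConfig` field, and the hand-written `isinstance` checks and `flatten_bn(...)` calls are replaced by `TensorSchema`-annotated input classes. A minimal sketch of the resulting pattern, assuming vLLM's multimodal interfaces; the "My*" names are hypothetical stand-ins, not code from this commit:

    # Sketch only: "My*" names are hypothetical stand-ins for the per-model classes.
    from typing import Annotated, Literal, Optional

    import torch
    import torch.nn as nn

    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class MyImagePixelInputs(TensorSchema):
        """Annotated shapes stand in for the manual isinstance/shape checks."""
        type: Literal["pixel_values"] = "pixel_values"
        pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


    class MyMultiModalModel(nn.Module, SupportsMultiModal):
        # With this flag set, pixel_values is expected to arrive as a single
        # merged tensor, so flatten_bn(..., concat=True) is no longer needed
        # inside the model.
        merge_by_field_config = True

        def _parse_and_validate_image_input(
                self, **kwargs: object) -> Optional[MyImagePixelInputs]:
            pixel_values = kwargs.pop("pixel_values", None)
            if pixel_values is None:
                return None
            return MyImagePixelInputs(pixel_values=pixel_values)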

vllm/model_executor/models/gemma3_mm.py (14 additions, 21 deletions)

@@ -36,7 +36,7 @@
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 logger = init_logger(__name__)
@@ -289,7 +289,7 @@ def _call_hf_processor(
                                     processor=hf_processor)
             for size in image_sizes
         ]
-        processed_outputs["num_crops"] = torch.tensor(num_crops)
+        processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
 
         return processed_outputs
 
@@ -298,12 +298,12 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
 
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+                "image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -460,6 +460,8 @@ def forward(self, vision_outputs: torch.Tensor):
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                      SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -526,29 +528,20 @@ def dtype(self):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        num_crops = kwargs.pop("num_crops", None)
+        num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         image_size = self.config.vision_config.image_size
 
-        return Gemma3ImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_crops, concat=True) + 1,
-            resolve_bindings={
-                "h": image_size,
-                "w": image_size
-            })
+        return Gemma3ImagePixelInputs(pixel_values=pixel_values,
+                                      num_patches=num_patches,
+                                      resolve_bindings={
+                                          "h": image_size,
+                                          "w": image_size
+                                      })
 
     def _image_pixels_to_features(
         self,
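
Note the arithmetic shift here: the processor now emits `num_patches` (`num_crops + 1`, presumably the base image plus its pan-and-scan crops), so `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` can split the flat `pixel_values` batch per image without the model re-deriving the `+ 1`. Conceptually, with made-up sizes (this is only an illustration of the split, not code from the diff):

    import torch

    # Hypothetical: two images, one without extra crops and one with three.
    num_crops = torch.tensor([0, 3])
    num_patches = num_crops + 1          # the base image counts as one patch
    pixel_values = torch.randn(int(num_patches.sum()), 3, 8, 8)  # tiny h/w for illustration

    # flat_from_sizes("image", num_patches) describes this kind of per-image
    # split of the flat batch; torch.split just demonstrates the arithmetic.
    per_image = torch.split(pixel_values, num_patches.tolist())
    assert [t.shape[0] for t in per_image] == [1, 4]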

vllm/model_executor/models/gemma3n_mm.py (31 additions, 38 deletions)

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -41,6 +41,7 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@
 TOKENS_PER_AUDIO = 188
 
 
-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
 
 
 Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ def _get_mm_fields_config(
 
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ def forward(
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
 
     packed_modules_mapping = {
@@ -482,14 +495,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)
 
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ def _parse_and_validate_image_input(
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)
 
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
            return None
 
         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None
 
-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
                              ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                 and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)

vllm/model_executor/models/glm4_1v.py (2 additions, 36 deletions)

@@ -1319,6 +1319,8 @@ def get_video_replacement_glm4v(item_idx: int):
 )
 class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1381,22 +1383,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Glm4vImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1407,23 +1393,13 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1440,23 +1416,13 @@ def _parse_and_validate_video_input(
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
                 video_grid_thw=video_grid_thw,
             )
 
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
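
The deleted `_validate_and_reshape_mm_tensor` existed only to collapse the per-request batch dimension (reshaping a batched 3D tensor to 2D, or concatenating a list of per-request tensors); it can go here presumably because the field-config-driven merge already delivers these inputs in that flat form. A small sketch of the behaviour the removed helper provided, with made-up sizes:

    import torch

    # What the removed helper effectively did for a batched 3D input:
    batched = torch.randn(2, 5, 1176)            # (requests, patches, hidden); sizes made up
    flat_from_3d = batched.reshape(-1, batched.shape[-1])

    # ...and for a list of per-request 2D tensors:
    as_list = [torch.randn(5, 1176), torch.randn(7, 1176)]
    flat_from_list = torch.concat(as_list)

    assert flat_from_3d.shape == (10, 1176)
    assert flat_from_list.shape == (12, 1176)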

vllm/model_executor/models/glm4v.py (6 additions, 9 deletions)

@@ -43,7 +43,6 @@
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn
 
 
 class GLMVImagePixelInputs(TensorSchema):
@@ -529,8 +528,9 @@ def get_replacement(item_idx: int):
 @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                         info=GLM4VProcessingInfo,
                                         dummy_inputs=GLM4VDummyInputsBuilder)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
-                       SupportsMultiModal):
+class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA,
+                       SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -574,14 +574,9 @@ def _parse_and_validate_image_input(
         pixel_values = kwargs.pop("pixel_values", None)
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             expected_h = expected_w = self.config.vision_config["image_size"]
             return GLMVImagePixelInputs(type="pixel_values",
-                                        data=flatten_bn(pixel_values,
-                                                        concat=True),
+                                        data=pixel_values,
                                         resolve_bindings={
                                             "h": expected_h,
                                             "w": expected_w
@@ -598,6 +593,8 @@ def _process_image_input(
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
+
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/granite_speech.py (3 additions, 4 deletions)

@@ -168,10 +168,8 @@ def _call_hf_processor(
         # Calculate the number of audio tokens per entry in the batch;
         # This is used to split the batch back out after padding.
         audio_token_index = self.info.get_hf_config().audio_token_index
-        processed_outputs["audio_embed_sizes"] = [
-            torch.sum(indices == audio_token_index).item()
-            for indices in processed_outputs["input_ids"]
-        ]
+        processed_outputs["audio_embed_sizes"] = (
+            processed_outputs["input_ids"] == audio_token_index).sum(-1)
 
         return processed_outputs
 
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
         SupportsPP,
         SupportsLoRA,
 ):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [
