Commit 7d6fb90

[Model] Use merge_by_field_config for MM models (A-C) (vllm-project#26073)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 418d111 commit 7d6fb90

5 files changed: +29 additions, −24 deletions

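Every file in this commit makes the same two-part change: drop the per-model `flatten_bn(..., concat=True)` calls (and the now-unused import from `.utils`), and set `merge_by_field_config = True` on the model class so that the shared multimodal input processing, rather than each model, merges the per-request fields before `_parse_and_validate_image_input` sees them. For reference, here is a minimal sketch of the flattening the models used to do themselves; the helper name is invented for illustration and this is not vLLM's `flatten_bn`.

```python
# Minimal sketch (illustrative, NOT vLLM's flatten_bn): collapse
# per-request image tensors into one tensor whose leading dim is the
# flat "bn" (batch * num_images) axis.
from typing import Union

import torch


def flatten_bn_sketch(
        x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        # Already batched: (b, n, ...) -> (b * n, ...)
        return x.flatten(0, 1)
    # List of per-request tensors: [(n_i, ...), ...] -> (sum(n_i), ...)
    return torch.cat(list(x), dim=0)


# Two requests with 2 and 3 images each, all 3 x 224 x 224:
imgs = [torch.zeros(2, 3, 224, 224), torch.zeros(3, 3, 224, 224)]
print(flatten_bn_sketch(imgs).shape)  # torch.Size([5, 3, 224, 224])
```

With the flag set, the parse methods in the diffs below pass `pixel_values` and friends straight through, consistent with the kwargs already arriving in this merged form.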

vllm/model_executor/models/aria.py

Lines changed: 13 additions & 6 deletions

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Optional, Union
+from typing import Annotated, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    is_pp_missing_parameter, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
+                    maybe_prefix)


 class AriaImagePixelInputs(TensorSchema):
@@ -52,6 +52,8 @@ class AriaImagePixelInputs(TensorSchema):
     - w: Width of each image
     """

+    type: Literal["pixel_values"]
+
     pixel_values: Annotated[
         torch.Tensor,
         TensorShape("bn", 3, "h", "w"),
@@ -485,6 +487,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     This model combines a vision tower, a multi-modal projector, and a language
     model to perform tasks that involve both image and text inputs.
     """
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -551,12 +555,15 @@ def _parse_and_validate_image_input(
             return None

         return AriaImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            pixel_mask=flatten_bn(pixel_mask, concat=True),
+            type="pixel_values",
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
         )

     def _create_patch_attention_mask(
-            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        self,
+        pixel_mask: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         if pixel_mask is None:
             return None
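
Beyond the flag, Aria's input schema gains an explicit `type: Literal["pixel_values"]` tag, and the parse step now forwards `pixel_values` / `pixel_mask` untouched. The snippet below is a condensed, self-contained analogue of that parse step, assuming the merged kwargs arrive as plain tensors; it mirrors the diff but is not the actual module (in the real file the schema is a `TensorSchema`, so the `TensorShape("bn", 3, "h", "w")` annotation still validates the merged tensor's shape).

```python
# Condensed analogue of the post-change parse step in aria.py.
# Illustrative only: a dataclass stands in for the TensorSchema class.
from dataclasses import dataclass
from typing import Literal, Optional

import torch


@dataclass
class AriaImagePixelInputsSketch:
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor            # ("bn", 3, "h", "w")
    pixel_mask: Optional[torch.Tensor]    # ("bn", "h", "w")


def parse_and_validate_image_input(
        **kwargs: object) -> Optional[AriaImagePixelInputsSketch]:
    pixel_values = kwargs.pop("pixel_values", None)
    pixel_mask = kwargs.pop("pixel_mask", None)

    if pixel_values is None:
        return None

    assert isinstance(pixel_values, torch.Tensor)
    assert pixel_mask is None or isinstance(pixel_mask, torch.Tensor)
    return AriaImagePixelInputsSketch(
        type="pixel_values",
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
    )
```
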
vllm/model_executor/models/aya_vision.py

Lines changed: 4 additions & 3 deletions

@@ -31,7 +31,7 @@

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)


@@ -295,6 +295,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
                                         dummy_inputs=AyaVisionDummyInputsBuilder)
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
+    merge_by_field_config = True

     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -379,8 +380,8 @@ def _parse_and_validate_image_input(

         return AyaVisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,

vllm/model_executor/models/blip2.py

Lines changed: 4 additions & 9 deletions

@@ -26,12 +26,7 @@
 from .blip import BlipVisionModel
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
-
-# We use this internally as placeholders since there is no image token
-# defined on the HuggingFace repo
-_IMAGE_TOKEN_ID = 50265
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix


 class Blip2ImagePixelInputs(TensorSchema):
@@ -514,6 +509,7 @@ def _get_prompt_updates(
                                     dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsQuant):
+    merge_by_field_config = True

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -570,8 +566,7 @@ def _parse_and_validate_image_input(
         if pixel_values is not None:
             expected_h = expected_w = self.config.vision_config.image_size
             return Blip2ImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w
@@ -580,7 +575,7 @@ def _parse_and_validate_image_input(
         if image_embeds is not None:
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         raise AssertionError("This line should be unreachable.")
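
BLIP-2 accepts either raw pixels or precomputed image embeddings, so its parse step returns one of two tagged inputs and keeps the unreachable-assertion guard at the end. Below is a compact, self-contained analogue of that branching after the change; the class and function names are stand-ins that mirror the diff, not the real vLLM definitions.

```python
# Illustrative analogue of blip2.py's two-way parse: pixel_values and
# image_embeds are mutually exclusive tagged inputs. Not the real module.
from dataclasses import dataclass
from typing import Literal, Optional, Union

import torch


@dataclass
class PixelInputs:
    type: Literal["pixel_values"]
    data: torch.Tensor         # ("bn", 3, "h", "w")


@dataclass
class EmbeddingInputs:
    type: Literal["image_embeds"]
    data: torch.Tensor         # ("bn", "f", "hidden_size")


ImageInputs = Union[PixelInputs, EmbeddingInputs]


def parse_image_input(**kwargs: object) -> Optional[ImageInputs]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is None and image_embeds is None:
        return None
    if pixel_values is not None:
        assert isinstance(pixel_values, torch.Tensor)
        return PixelInputs(type="pixel_values", data=pixel_values)
    if image_embeds is not None:
        assert isinstance(image_embeds, torch.Tensor)
        return EmbeddingInputs(type="image_embeds", data=image_embeds)

    raise AssertionError("This line should be unreachable.")
```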

vllm/model_executor/models/chameleon.py

Lines changed: 4 additions & 3 deletions

@@ -42,7 +42,7 @@

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (flatten_bn, is_pp_missing_parameter,
+from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -935,6 +935,8 @@ def forward(
                                      dummy_inputs=ChameleonDummyInputsBuilder)
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP, SupportsQuant):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
@@ -981,8 +983,7 @@ def _parse_and_validate_image_input(
         expected_h = expected_w = vq_config.resolution

         return ChameleonImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w

vllm/model_executor/models/cohere2_vision.py

Lines changed: 4 additions & 3 deletions

@@ -36,7 +36,7 @@

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)


@@ -317,6 +317,7 @@ def get_replacement(item_idx: int):
                                     dummy_inputs=Cohere2VisionDummyInputsBuilder)
 class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
+    merge_by_field_config = True

     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -399,8 +400,8 @@ def _parse_and_validate_image_input(

         return Cohere2VisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,
