Commit cc253b7

[Model] Use merge_by_field_config for MM models (D-F) (vllm-project#26076)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 7d6fb90 commit cc253b7

File tree

4 files changed: +99 -177 lines changed
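
All three diffs below follow the same pattern: the model class opts into merged multimodal inputs, and the per-model flattening/validation code is deleted. A minimal sketch of that opt-in, with an illustrative class name (the real classes also subclass nn.Module and vLLM's multimodal interfaces):

# Minimal sketch of the opt-in applied to each model in this commit; the class
# name here is invented for illustration only.
class ExampleMultiModalModel:
    # With this class attribute set, vLLM's multimodal pipeline merges the
    # keyword inputs according to each field's MultiModalFieldConfig before
    # _parse_and_validate_image_input runs, which is why the diffs below can
    # drop flatten_bn and the hand-rolled _validate_and_reshape_mm_tensor
    # helpers.
    merge_by_field_config = True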

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 25 additions & 33 deletions
@@ -20,8 +20,7 @@
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@
 class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
         - p: Number of patches
         - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
     type: Literal["pixel_values"]
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
 
 
@@ -228,12 +227,8 @@ def _call_hf_processor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
 
@@ -242,8 +237,11 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )
@@ -318,6 +316,7 @@ def _cached_apply_hf_processor(
    info=DeepseekVL2ProcessingInfo,
    dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
@@ -460,37 +459,30 @@ def _parse_and_validate_image_input(
 
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
-                                               resolve_bindings={
-                                                   "h": expected_h,
-                                                   "w": expected_w,
-                                               })
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
-
+    ) -> list[torch.Tensor]:
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ def _pixel_values_to_embedding(
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):
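
To make the deepseek_vl2.py change concrete: instead of splitting pixel_values per image inside the HF processor, the processor now records num_patches, and the field config uses MultiModalFieldConfig.flat_from_sizes so the model receives one flat patch tensor. A hedged sketch of that bookkeeping in plain torch, with made-up sizes (this is not the vLLM field-config machinery itself):

import torch

# Toy illustration of the num_patches bookkeeping in the diff above: per-image
# patch counts let one flat (bnp, 3, h, w) tensor be split back per image.
# The crop grid and the 384x384 size are made-up example values.
images_spatial_crop = torch.tensor([[2, 3], [1, 1]])  # (num_images, 2)

# Same arithmetic as the new _call_hf_processor code: product over the crop
# grid plus one -> tensor([7, 2]).
num_patches = images_spatial_crop.prod(-1) + 1

# Flat "bnp" layout: all patches of all images stacked along dim 0.
pixel_values = torch.randn(int(num_patches.sum()), 3, 384, 384)

# Recover per-image chunks from the flat layout when needed.
per_image = pixel_values.split(num_patches.tolist(), dim=0)
assert [t.shape[0] for t in per_image] == num_patches.tolist()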

vllm/model_executor/models/dots_ocr.py

Lines changed: 24 additions & 51 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -42,34 +42,38 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .vision import run_dp_sharded_mrope_vision_model
 
 IMAGE_TOKEN = "<|imgpad|>"
 
 
-class DotsOCRImagePixelInputs(TypedDict):
-    type: Literal["pixel_values", "image_grid_thw"]
+class DotsOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+              the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
+    type: Literal["pixel_values"]
 
-    pixel_values: torch.Tensor
-    image_grid_thw: torch.Tensor
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-class DotsOCRImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds", "image_grid_thw"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - List[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-        Tensor shape: `(num_image_features, hidden_size)`
-        - `num_image_features` varies based on
-            the number and resolution of the images.
-        - `hidden_size` must match the hidden size of language model backbone.
+class DotsOCRImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+    """
+    type: Literal["image_embeds"]
 
-    image_grid_thw: torch.Tensor
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
@@ -654,6 +658,8 @@ def forward(self, hidden_states: torch.Tensor,
 )
 class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                          SupportsLoRA):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             architectures=["Qwen2ForCausalLM"],
         )
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return torch.concat(list(mm_input))
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return DotsOCRImagePixelInputs(type="pixel_values",
                                            pixel_values=pixel_values,
                                            image_grid_thw=image_grid_thw)
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return DotsOCRImageEmbeddingInputs(type="image_embeds",
                                                image_embeds=image_embeds,
                                                image_grid_thw=image_grid_thw)
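
Both dots_ocr.py (above) and ernie45_vl.py (below) replace plain TypedDict input classes with TensorSchema subclasses whose fields carry TensorShape annotations, so shape expectations are declared once and checked when the inputs object is built. A toy re-creation of that declaration style, for illustration only; the Toy* names are invented here and the real implementation in vllm.utils.tensor_schema is more capable:

from typing import Annotated, get_type_hints

import torch


class ToyTensorShape:
    """Symbolic shape spec: strings are named dims, ints are fixed sizes."""

    def __init__(self, *dims):
        self.dims = dims


class ToyTensorSchema:
    """Checks each annotated field's rank and fixed dims on construction."""

    def __init__(self, **fields):
        hints = get_type_hints(type(self), include_extras=True)
        for name, value in fields.items():
            for meta in getattr(hints.get(name), "__metadata__", ()):
                if isinstance(meta, ToyTensorShape):
                    assert value.ndim == len(meta.dims), f"{name}: wrong rank"
                    for i, dim in enumerate(meta.dims):
                        if isinstance(dim, int):
                            assert value.shape[i] == dim, f"{name}: dim {i}"
            setattr(self, name, value)


class ToyImageInputs(ToyTensorSchema):
    pixel_values: Annotated[torch.Tensor, ToyTensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, ToyTensorShape("ni", 3)]


# Passes: ranks match and the fixed dim of image_grid_thw is 3.
ToyImageInputs(pixel_values=torch.randn(12, 588),
               image_grid_thw=torch.ones(2, 3, dtype=torch.long))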

vllm/model_executor/models/ernie45_vl.py

Lines changed: 24 additions & 52 deletions
@@ -25,7 +25,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -56,6 +56,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ def load_weights(self, weights) -> set[str]:
 # === Vision Inputs === #
 
 
-class Ernie4_5_VLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
+class Ernie4_5_VLImagePixelInputs(TensorSchema):
     """
-
-    grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+              the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
     """
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
 
 
-class Ernie4_5_VLVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-     num_channels * temporal_patch_size * patch_size * patch_size)`
+class Ernie4_5_VLVideoPixelInputs(TensorSchema):
     """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+              the batch
+        - ni: Number of images
+        - cps: Number of channels * temporal_patch_size * patch_size *
+               patch_size
     """
+    type: Literal["pixel_values_videos"]
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
+Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
 
 # === Vision Processor === #
 
@@ -1213,6 +1214,7 @@ def get_dummy_mm_data(
     dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
 class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1325,22 +1327,6 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1350,15 +1336,6 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Ernie4_5_VLImagePixelInputs(type="pixel_values",
                                                pixel_values=pixel_values,
                                                image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ def _parse_and_validate_video_input(
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Ernie4_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
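
For reference, the deleted _validate_and_reshape_mm_tensor helpers in ernie45_vl.py and dots_ocr.py covered two cases by hand: a batched 3D tensor was flattened, and a list of 2D tensors was concatenated. With merge_by_field_config the inputs are expected to arrive already in the flat layout, so the helpers go away. A small sketch of those two code paths, using toy shapes only:

import torch

# Toy shapes: cps = channels * patch_size * patch_size, as in the schemas above.
cps = 3 * 14 * 14

# Path 1 of the deleted helper: a batched 3D tensor (items, patches, cps)
# was reshaped to a flat (total_patches, cps) tensor.
batched = torch.randn(2, 5, cps)
flat_from_batched = batched.reshape(-1, batched.shape[-1])
assert flat_from_batched.shape == (10, cps)

# Path 2: a list of per-item 2D tensors (possibly with ragged patch counts)
# was concatenated along dim 0.
per_item = [torch.randn(5, cps), torch.randn(7, cps)]
flat_from_list = torch.concat(per_item)
assert flat_from_list.shape == (12, cps)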
