Commit 671427e

[Model] Move multimodal_cpu_fields definition to field config (vllm-project#30181)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 21bb323 commit 671427e

15 files changed (+141, -95 lines)
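The gist of the change: per-model `multimodal_cpu_fields` class attributes are removed, and CPU-only multimodal fields are instead flagged with `keep_on_cpu=True` directly in the field config. A minimal before/after sketch follows (the model class and `_my_field_config` helper are hypothetical; the import path for `MultiModalFieldConfig` is assumed):

    from vllm.multimodal.inputs import MultiModalFieldConfig  # import path assumed


    # Before (removed by this commit): CPU-only fields were declared on the model class.
    class MyVLModel:  # hypothetical model implementing SupportsMultiModal
        multimodal_cpu_fields = {"image_grid_thw"}  # setting this now raises ValueError


    # After: the flag lives on the field config itself.
    def _my_field_config(hf_inputs):
        # `image_grid_thw` stays on the CPU instead of being moved to the device.
        return dict(
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
        )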

tests/distributed/test_shm_storage.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
         modality=modality,
         key=key,
         data=torch.empty((size,), dtype=torch.int8),
-        field=MultiModalSharedField(1),
+        field=MultiModalSharedField(batch_size=1),
     )

tests/multimodal/test_cache.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def _dummy_elem(
         modality=modality,
         key=key,
         data=data,
-        field=MultiModalSharedField(1),
+        field=MultiModalSharedField(batch_size=1),
     )

tests/v1/test_serial_utils.py

Lines changed: 15 additions & 6 deletions
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
 
 def test_multimodal_kwargs():
     e1 = MultiModalFieldElem(
-        "audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
+        "audio",
+        "a0",
+        torch.zeros(1000, dtype=torch.bfloat16),
+        MultiModalBatchedField(),
     )
     e2 = MultiModalFieldElem(
         "video",
         "v0",
         [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
-        MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
+        MultiModalFlatField(
+            slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
+            dim=0,
+        ),
     )
     e3 = MultiModalFieldElem(
-        "image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
+        "image",
+        "i0",
+        torch.zeros(1000, dtype=torch.int32),
+        MultiModalSharedField(batch_size=4),
     )
     e4 = MultiModalFieldElem(
         "image",
         "i1",
         torch.zeros(1000, dtype=torch.int32),
-        MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
+        MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
     )
     audio = MultiModalKwargsItem.from_elems([e1])
     video = MultiModalKwargsItem.from_elems([e2])
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
 
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
-    # expected total encoding length, should be 14306, +-20 for minor changes
-    assert 14275 <= total_len <= 14325
+    # expected total encoding length, should be 14395, +-20 for minor changes
+    assert 14375 <= total_len <= 14425
     decoded = decoder.decode(encoded).mm[0]
     assert isinstance(decoded, MultiModalKwargsItems)
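For reference, a minimal standalone sketch of building a kwargs item with the keyword-style field constructors exercised above, mirroring `e3` from the test (the `vllm.multimodal.inputs` import path is assumed):

    import torch

    from vllm.multimodal.inputs import (  # import path assumed
        MultiModalFieldElem,
        MultiModalKwargsItem,
        MultiModalSharedField,
    )

    # One image tensor shared across a batch of 4 items; `batch_size` is now
    # passed by keyword rather than positionally.
    elem = MultiModalFieldElem(
        "image",
        "i0",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(batch_size=4),
    )
    item = MultiModalKwargsItem.from_elems([elem])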

vllm/model_executor/models/glm4_1v.py

Lines changed: 11 additions & 18 deletions
@@ -787,10 +787,10 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
-        # Convert grid_thw to tensor (always expecting list format now)
-        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+        if isinstance(grid_thw, list):
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
 
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
@@ -805,7 +805,8 @@ def forward(
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
         max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
@@ -1548,7 +1549,6 @@ def _process_image_input(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1559,20 +1559,17 @@ def _process_image_input(
                 self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
             )
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)
 
     def _process_video_input(
         self, video_input: Glm4vVideoInputs
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1588,15 +1585,11 @@ def _process_video_input(
                 rope_type="rope_3d",
             )
         else:
-            video_embeds = self.visual(
-                pixel_values_videos, grid_thw=grid_thw.tolist()
-            )
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
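As a toy illustration of the tensor-only path above (not part of the diff; the `grid_thw` values and `merge_size = 2` are made up), the cumulative sequence lengths and per-item split sizes work out as follows:

    import torch

    # Two fake items: (t=1, h=4, w=6) and (t=2, h=2, w=2).
    grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]], dtype=torch.int32)

    # Each item contributes t rows of h*w patches.
    cu_seqlens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(dim=0, dtype=torch.int32)

    # Prepend the leading zero with torch.cat instead of F.pad, as in the new code.
    cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
    print(cu_seqlens.tolist())  # [0, 24, 28, 32]

    # Per-item embedding sizes after spatial merging (merge_size is made up here).
    merge_size = 2
    sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
    print(sizes)  # [6, 2] -> embeddings.split(sizes) yields per-item chunks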

vllm/model_executor/models/hunyuan_vision.py

Lines changed: 1 addition & 3 deletions
@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
     return dict(
         pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
         image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
+        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
     )
@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
     SupportsQuant,
     SupportsXDRoPE,
 ):
-    multimodal_cpu_fields = {"image_grid_thw"}
-
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={

vllm/model_executor/models/interfaces.py

Lines changed: 11 additions & 2 deletions
@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
     """
 
-    multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
+    multimodal_cpu_fields: ClassVar[Set[str] | None] = None
     """
-    A set indicating CPU-only multimodal fields.
+    [DEPRECATED] A set indicating CPU-only multimodal fields.
     """
 
     _processor_factory: ClassVar[_ProcessorFactories]
@@ -279,6 +279,15 @@ def supports_multimodal(
             "please remove the override from your model."
         )
 
+    multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
+    if multimodal_cpu_fields is not None:
+        raise ValueError(
+            "`multimodal_cpu_fields` is no longer effective, "
+            "please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
+            "(refer to https://github.com/vllm-project/vllm/pull/30181), "
+            "and then remove the override from your model."
+        )
+
     return res

vllm/model_executor/models/opencua.py

Lines changed: 0 additions & 2 deletions
@@ -201,8 +201,6 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
     dummy_inputs=OpenCUADummyInputsBuilder,
 )
 class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
-    multimodal_cpu_fields = {"image_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 0 additions & 2 deletions
@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMultiModalPruning,
     SupportsMRoPE,
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/qwen2_vl.py

Lines changed: 6 additions & 6 deletions
@@ -811,14 +811,14 @@ def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         image_embeds=MultiModalFieldConfig.flat_from_sizes(
             "image", image_embed_grid_sizes
         ),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
+        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
         pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
             "video", video_grid_sizes
         ),
         video_embeds=MultiModalFieldConfig.flat_from_sizes(
             "video", video_embed_grid_sizes
         ),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
+        video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
     )
 
     return _qwen2vl_field_config
@@ -1131,8 +1131,6 @@ def _get_mm_fields_config(
 class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1393,9 +1391,11 @@ def _process_video_input(
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
-                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
+                    self.visual,
+                    pixel_values_videos,
+                    grid_thw.tolist(),
+                    rope_type="rope_3d",
                 )
             else:
                 video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

vllm/model_executor/models/qwen3_vl.py

Lines changed: 2 additions & 4 deletions
@@ -984,14 +984,14 @@ def _get_mm_fields_config(
             image_embeds=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_grid_sizes
             ),
-            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
             pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                 "video", video_grid_sizes
            ),
             video_embeds=MultiModalFieldConfig.flat_from_sizes(
                 "video", video_grid_sizes
             ),
-            video_grid_thw=MultiModalFieldConfig.batched("video"),
+            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
         )
 
     def _get_prompt_updates(
@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
     SupportsMRoPE,
     SupportsEagle3,
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
