Commit 22aeb43

[Bugfix][VLM] Fix transformers backend embed_multimodal for Qwen2.5-VL profiling (vllm-project#32969)
Signed-off-by: Andreas Karatzas <[email protected]>
1 parent a698e8e commit 22aeb43
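
A note on the failure mode: the previous code handed num_image_patches directly to torch.split, which raises a RuntimeError whenever the sizes do not sum to the tensor's length along the split dimension. The new code instead derives token-level split sizes, falling back to repeat-and-truncate when profiling with dummy data produces a mismatched token count. Below is a minimal sketch of the old failure, with made-up shapes (the real Qwen2.5-VL tensors differ):

import torch

# Made-up stand-ins: two images whose patch counts sum to 10, but a
# vision tower that returned only 4 embedding rows (dummy profiling data).
vision_embeddings = torch.randn(4, 8)
num_image_patches = torch.tensor([6, 4])

# Pre-fix behavior: torch.split requires the sizes to sum exactly to the
# length of the split dimension, so the mismatch raises a RuntimeError.
try:
    torch.split(vision_embeddings, num_image_patches.flatten().tolist())
except RuntimeError as err:
    print(f"split failed: {err}")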

File tree

1 file changed: +38 -13 lines

vllm/model_executor/models/transformers/multimodal.py

Lines changed: 38 additions & 13 deletions

@@ -386,19 +386,44 @@ def embed_multimodal(self, **kwargs):
                 vision_embeddings = vision_embeddings.pooler_output
 
             if isinstance(vision_embeddings, torch.Tensor):
-                if vision_embeddings.ndim == 2:
-                    vision_embeddings = vision_embeddings.unsqueeze(0)
-
-                # Embeddings have to be 2D tensors of length `num_images`
-                # but transformers returns concat tensors if each patch
-                # is of different size. We split it back to make vLLM happy
-                vision_embeddings = torch.split(
-                    vision_embeddings, num_image_patches.flatten().tolist()
-                )
-                vision_embeddings = [
-                    embed.flatten(start_dim=0, end_dim=-2)
-                    for embed in vision_embeddings
-                ]
+                split_sizes = num_image_patches.flatten().tolist()
+                total_patches = sum(split_sizes)
+
+                # Flatten to 2D: [total_tokens, hidden_dim]
+                if vision_embeddings.ndim == 3:
+                    vision_embeddings = vision_embeddings.view(
+                        -1, vision_embeddings.shape[-1]
+                    )
+
+                total_tokens = vision_embeddings.shape[0]
+                if total_tokens == total_patches:
+                    # Direct match: num_image_patches are actual token counts
+                    # (e.g., Qwen2.5-VL style)
+                    token_split_sizes = split_sizes
+                elif total_patches > 0 and total_tokens % total_patches == 0:
+                    # Uniform expansion: each patch expands to N tokens
+                    # (e.g., Idefics3 style)
+                    tokens_per_patch = total_tokens // total_patches
+                    token_split_sizes = [s * tokens_per_patch for s in split_sizes]
+                elif total_patches > 0:
+                    # Mismatch (profiling with dummy data) - pad/truncate
+                    if total_tokens == 0:
+                        raise ValueError(
+                            "Vision encoder returned empty embeddings. "
+                            f"Expected {total_patches} patches from "
+                            f"num_image_patches={split_sizes}"
+                        )
+                    if total_tokens < total_patches:
+                        repeat_factor = (
+                            total_patches + total_tokens - 1
+                        ) // total_tokens
+                        vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
+                    vision_embeddings = vision_embeddings[:total_patches]
+                    token_split_sizes = split_sizes
+                else:
+                    return []
+
+                return list(torch.split(vision_embeddings, token_split_sizes, dim=0))
 
             return vision_embeddings
         else:
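
To see how the three branches behave, here is a standalone sketch that mirrors the added logic outside vLLM. split_vision_embeddings is a hypothetical helper written for this note, not a vLLM API, and the shapes are dummy data covering the three regimes named in the diff comments:

import torch

def split_vision_embeddings(
    vision_embeddings: torch.Tensor, num_image_patches: torch.Tensor
) -> list[torch.Tensor]:
    # Mirrors the logic added in this commit, lifted out as a free function.
    split_sizes = num_image_patches.flatten().tolist()
    total_patches = sum(split_sizes)

    # Flatten to 2D: [total_tokens, hidden_dim]
    if vision_embeddings.ndim == 3:
        vision_embeddings = vision_embeddings.view(-1, vision_embeddings.shape[-1])

    total_tokens = vision_embeddings.shape[0]
    if total_tokens == total_patches:
        # Direct match (Qwen2.5-VL style): patch counts are token counts.
        token_split_sizes = split_sizes
    elif total_patches > 0 and total_tokens % total_patches == 0:
        # Uniform expansion (Idefics3 style): each patch is N tokens.
        tokens_per_patch = total_tokens // total_patches
        token_split_sizes = [s * tokens_per_patch for s in split_sizes]
    elif total_patches > 0:
        # Mismatch (profiling with dummy data): tile the rows up to the
        # expected patch count, then truncate.
        if total_tokens == 0:
            raise ValueError("Vision encoder returned empty embeddings.")
        if total_tokens < total_patches:
            repeat_factor = (total_patches + total_tokens - 1) // total_tokens
            vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
        vision_embeddings = vision_embeddings[:total_patches]
        token_split_sizes = split_sizes
    else:
        return []
    return list(torch.split(vision_embeddings, token_split_sizes, dim=0))

# Direct match: 6 + 4 = 10 rows split into [6, 4].
out = split_vision_embeddings(torch.randn(10, 8), torch.tensor([6, 4]))
assert [t.shape[0] for t in out] == [6, 4]

# Uniform expansion: 2 patches, 6 rows -> 3 tokens per patch, splits [3, 3].
out = split_vision_embeddings(torch.randn(6, 8), torch.tensor([1, 1]))
assert [t.shape[0] for t in out] == [3, 3]

# Profiling mismatch: 4 rows but 10 patches expected -> repeat, truncate, split.
out = split_vision_embeddings(torch.randn(4, 8), torch.tensor([7, 3]))
assert [t.shape[0] for t in out] == [7, 3]

The ceiling division in the mismatch branch guarantees at least total_patches rows before truncation, so profiling with undersized dummy embeddings still produces per-image chunks of the sizes the rest of vLLM expects.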
