Skip to content

Commit 92e8db3

Browse files
committed
Fix for merge_multimodal_embeddedings() crash
1 parent e684eb5 commit 92e8db3

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

vllm/model_executor/models/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,18 @@ def merge_multimodal_embeddings(
553553
This updates ``inputs_embeds`` in place.
554554
"""
555555
if isinstance(placeholder_token_id, list):
556-
placeholder_token_id = torch.tensor(placeholder_token_id,
557-
device=input_ids.device)
556+
flat_ids = []
557+
for x in placeholder_token_id:
558+
if torch.is_tensor(x):
559+
flat_ids.extend(int(v) for v in x.reshape(-1).tolist())
560+
elif isinstance(x, (list, tuple)):
561+
flat_ids.extend(int(v) for v in x)
562+
else:
563+
flat_ids.append(int(x))
564+
565+
placeholder_token_id = torch.as_tensor(
566+
flat_ids, device=input_ids.device, dtype=input_ids.dtype
567+
)
558568
return _merge_multimodal_embeddings(
559569
inputs_embeds,
560570
torch.isin(input_ids, placeholder_token_id),

0 commit comments

Comments
 (0)