@@ -386,19 +386,44 @@ def embed_multimodal(self, **kwargs):
                 vision_embeddings = vision_embeddings.pooler_output
 
             if isinstance(vision_embeddings, torch.Tensor):
-                if vision_embeddings.ndim == 2:
-                    vision_embeddings = vision_embeddings.unsqueeze(0)
-
-                # Embeddings have to be 2D tensors of length `num_images`
-                # but transformers returns concat tensors if each patch
-                # is of different size. We split it back to make vLLM happy
-                vision_embeddings = torch.split(
-                    vision_embeddings, num_image_patches.flatten().tolist()
-                )
-                vision_embeddings = [
-                    embed.flatten(start_dim=0, end_dim=-2)
-                    for embed in vision_embeddings
-                ]
+                split_sizes = num_image_patches.flatten().tolist()
+                total_patches = sum(split_sizes)
+
+                # Flatten to 2D: [total_tokens, hidden_dim]
+                if vision_embeddings.ndim == 3:
+                    vision_embeddings = vision_embeddings.view(
+                        -1, vision_embeddings.shape[-1]
+                    )
+
+                total_tokens = vision_embeddings.shape[0]
+                if total_tokens == total_patches:
+                    # Direct match: num_image_patches are actual token counts
+                    # (e.g., Qwen2.5-VL style)
+                    token_split_sizes = split_sizes
+                elif total_patches > 0 and total_tokens % total_patches == 0:
+                    # Uniform expansion: each patch expands to N tokens
+                    # (e.g., Idefics3 style)
+                    tokens_per_patch = total_tokens // total_patches
+                    token_split_sizes = [s * tokens_per_patch for s in split_sizes]
+                elif total_patches > 0:
+                    # Mismatch (profiling with dummy data) - pad/truncate
+                    if total_tokens == 0:
+                        raise ValueError(
+                            "Vision encoder returned empty embeddings. "
+                            f"Expected {total_patches} patches from "
+                            f"num_image_patches={split_sizes}"
+                        )
+                    if total_tokens < total_patches:
+                        repeat_factor = (
+                            total_patches + total_tokens - 1
+                        ) // total_tokens
+                        vision_embeddings = vision_embeddings.repeat(repeat_factor, 1)
+                    vision_embeddings = vision_embeddings[:total_patches]
+                    token_split_sizes = split_sizes
+                else:
+                    return []
+
+                return list(torch.split(vision_embeddings, token_split_sizes, dim=0))
 
             return vision_embeddings
         else:
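
For reference, a minimal standalone sketch of the splitting behaviour this hunk introduces, exercising the "uniform expansion" branch. All shapes and values here (hidden_dim, the per-patch expansion factor, the two-image patch counts) are made-up illustration values, not taken from the PR:

```python
# Illustrative sketch only: shows how token_split_sizes recovers per-image
# 2D embeddings from a flat [total_tokens, hidden_dim] tensor. Dummy data.
import torch

hidden_dim = 8
num_image_patches = torch.tensor([3, 5])      # two images: 3 and 5 patches
split_sizes = num_image_patches.flatten().tolist()
total_patches = sum(split_sizes)              # 8

# Assume the vision encoder emits 4 tokens per patch (Idefics3-style expansion).
tokens_per_patch = 4
vision_embeddings = torch.randn(total_patches * tokens_per_patch, hidden_dim)

total_tokens = vision_embeddings.shape[0]     # 32
assert total_tokens % total_patches == 0
token_split_sizes = [s * (total_tokens // total_patches) for s in split_sizes]

per_image = list(torch.split(vision_embeddings, token_split_sizes, dim=0))
print([tuple(t.shape) for t in per_image])    # [(12, 8), (20, 8)]
```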