Skip to content

Commit 0e98540

Browse files
committed
add vision embedding in eagle input
Signed-off-by: h-guo18 <[email protected]>
1 parent 96d71ae commit 0e98540

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,37 @@ def msk(b, h, q_idx, kv_idx):
646646
).masked_fill(~tensor_mask, dtypemin)
647647
return tensor_mask
648648

649+
def _llm_or_vlm_embedding(self, input_ids, kwargs):
650+
"""Return input embeddings with possibly vision embeddings for VLM."""
651+
tok_embeds = self._base_model_embeddings(input_ids)
652+
653+
# LLM only have token embeddings
654+
if "pixel_values" not in kwargs:
655+
return tok_embeds
656+
657+
# Otherwise, insert vision embeddings in tok_embeds
658+
if self.config.model_type == "NemotronH_Nano_VL_V2":
659+
vit_embeds = self.extract_feature(kwargs["pixel_values"])
660+
vit_embeds = vit_embeds[kwargs["image_flags"] == 1]
661+
bs, seq_len, hid_size = tok_embeds.shape
662+
tok_embeds = tok_embeds.reshape(bs * seq_len, hid_size)
663+
input_ids = input_ids.reshape(bs * seq_len)
664+
selected = input_ids == self.img_context_token_id
665+
try:
666+
tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds.reshape(-1, hid_size)
667+
except Exception as e:
668+
vit_embeds = vit_embeds.reshape(-1, hid_size)
669+
print(
670+
f"warning: {e}, tok_embeds[selected].shape={tok_embeds[selected].shape}, "
671+
f"vit_embeds.shape={vit_embeds.shape}"
672+
)
673+
n_token = selected.sum()
674+
tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds[:n_token]
675+
del vit_embeds
676+
return tok_embeds.reshape(bs, seq_len, hid_size)
677+
else:
678+
raise ValueError(f"VLM model type {self.config.model_type} not supported")
679+
649680
def _base_model_forward(
650681
self,
651682
input_ids,
@@ -811,7 +842,8 @@ def forward(
811842
eagle_cache,
812843
)
813844
with torch.no_grad():
814-
inputs_embeds = self._base_model_embeddings(eagle_input_ids)
845+
inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
846+
815847
position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
816848

817849
# Then, we run eagle forward

0 commit comments

Comments
 (0)