@@ -646,6 +646,37 @@ def msk(b, h, q_idx, kv_idx):
         ).masked_fill(~tensor_mask, dtypemin)
         return tensor_mask

+    def _llm_or_vlm_embedding(self, input_ids, kwargs):
+        """Return input embeddings, with vision embeddings spliced in for VLM inputs."""
+        tok_embeds = self._base_model_embeddings(input_ids)
+
+        # Text-only (LLM) inputs have only token embeddings
+        if "pixel_values" not in kwargs:
+            return tok_embeds
+
+        # Otherwise, insert vision embeddings into tok_embeds
+        if self.config.model_type == "NemotronH_Nano_VL_V2":
+            vit_embeds = self.extract_feature(kwargs["pixel_values"])
+            vit_embeds = vit_embeds[kwargs["image_flags"] == 1]
+            bs, seq_len, hid_size = tok_embeds.shape
+            tok_embeds = tok_embeds.reshape(bs * seq_len, hid_size)
+            input_ids = input_ids.reshape(bs * seq_len)
+            selected = input_ids == self.img_context_token_id
+            try:
+                tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds.reshape(-1, hid_size)
+            except Exception as e:
+                vit_embeds = vit_embeds.reshape(-1, hid_size)
+                print(
+                    f"warning: {e}, tok_embeds[selected].shape={tok_embeds[selected].shape}, "
+                    f"vit_embeds.shape={vit_embeds.shape}"
+                )
+                n_token = selected.sum()
+                tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds[:n_token]
+            del vit_embeds
+            return tok_embeds.reshape(bs, seq_len, hid_size)
+        else:
+            raise ValueError(f"VLM model type {self.config.model_type} not supported")
+
     def _base_model_forward(
         self,
         input_ids,
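
The VLM branch above is essentially a masked scatter: the token embeddings are flattened, the positions holding the image context token are selected, and the ViT features are written into those positions (the `* 0.0 +` form keeps the original token embeddings in the expression rather than discarding them outright). The following is a minimal, self-contained sketch of that pattern, using made-up shapes and a hypothetical `img_context_token_id`; it is only an illustration, not code from this commit.

import torch

# Made-up shapes; in the model these come from the tokenizer and the vision tower.
bs, seq_len, hid_size = 2, 6, 8
img_context_token_id = 99  # hypothetical id of the image placeholder token

input_ids = torch.randint(0, 50, (bs, seq_len))
input_ids[0, 1:4] = img_context_token_id          # pretend three image tokens in sample 0
tok_embeds = torch.randn(bs, seq_len, hid_size)
vit_embeds = torch.randn(3, hid_size)             # one vision feature per image token

# Flatten, select the image-token positions, and overwrite them with vision features.
flat = tok_embeds.reshape(bs * seq_len, hid_size)
selected = input_ids.reshape(bs * seq_len) == img_context_token_id
flat[selected] = flat[selected] * 0.0 + vit_embeds.reshape(-1, hid_size)

merged = flat.reshape(bs, seq_len, hid_size)      # back to (bs, seq_len, hid_size)

If the number of selected positions ever disagrees with the number of vision features, the except branch in the helper falls back to truncating vit_embeds to selected.sum() tokens and prints a warning.
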
@@ -811,7 +842,8 @@ def forward(
             eagle_cache,
         )
         with torch.no_grad():
-            inputs_embeds = self._base_model_embeddings(eagle_input_ids)
+            inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
+
         position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)

         # Then, we run eagle forward
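
In this second hunk the embedding lookup in forward is routed through the new helper while staying under torch.no_grad(), so the merged (token + vision) input embeddings are produced without building an autograd graph, while the EAGLE layers that consume them still do. A small stand-alone sketch of that pattern; embed and proj below are invented stand-ins, not names from this codebase.

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)   # stand-in for the base-model embedding table (not trained through this path)
proj = nn.Linear(4, 4)        # stand-in for a trainable draft-model layer

ids = torch.tensor([[1, 2, 3]])
with torch.no_grad():
    inputs_embeds = embed(ids)        # computed without tracking gradients

loss = proj(inputs_embeds).sum()      # downstream computation is still tracked
loss.backward()
print(embed.weight.grad is None)      # True: no gradient flows into the embedding table
print(proj.weight.grad is not None)   # True: the downstream layer still gets gradients
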