@@ -637,8 +637,14 @@ def _update_metadata(self,
     def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
-        inputs_embeds = self.model.get_input_embeddings(
-            input_ids, vision_embeddings)
+        if 'image_index' in kwargs:
+            inputs_embeds = self.model.get_input_embeddings_hpu(
+                input_ids, kwargs['image_index'], vision_embeddings)
+            kwargs.pop("image_index", None)
+        else:
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids, vision_embeddings)
+
         # TODO: In warmup, we need to warmup the model with dummy image data for
         # multimodal model for prompt, here instead of generating a dummy image,
         # we are just generating attn_mask for the images and pass with
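When a precomputed `image_index` is present, the embedding step switches to an HPU-specific helper that already knows where the image placeholder tokens sit in the prompt. The snippet below is a minimal sketch of what such a `get_input_embeddings_hpu` path might do, assuming `image_index` holds flat positions into the padded `input_ids` and the vision embeddings flatten to `[num_image_tokens, hidden]`; the function name `scatter_vision_embeddings` and the exact call shape are illustrative, not the model's actual API.

```python
import torch

def scatter_vision_embeddings(embed_tokens, input_ids, image_index,
                              vision_embeddings):
    """Sketch only: merge vision features into the text embeddings at the
    precomputed flat positions of the image placeholder tokens."""
    inputs_embeds = embed_tokens(input_ids)              # [B, T, H]
    batch, seq_len, hidden = inputs_embeds.shape
    flat = inputs_embeds.view(batch * seq_len, hidden)
    # Assumed layout: vision_embeddings reshapes to [num_image_tokens, H] and
    # image_index holds one flat offset per image token, in the same order.
    flat.index_copy_(0, image_index,
                     vision_embeddings.reshape(-1, hidden).to(flat.dtype))
    return flat.view(batch, seq_len, hidden)
```

Presumably the point of passing the index in is that the placeholder mask is computed once on the CPU side in `_prepare_prompt` rather than recomputed inside the HPU forward pass.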
@@ -1772,6 +1778,7 @@ def _prepare_prompt(
                                               pad=0,
                                               dtype=torch.long,
                                               flat=self.use_merged_prefill)
+        image_index_tensor = None
         if self.model_is_mrope:
             input_positions = \
                 make_mrope_positions_tensor_with_pad(input_positions=input_positions,
@@ -1785,6 +1792,11 @@ def _prepare_prompt(
                                               dtype=torch.long,
                                               flat=self.use_merged_prefill)

+        if seq_group_metadata.multi_modal_data and self.is_mm_optimized and \
+                'InternVLChatModel' in str(type(self.model.model)):
+            is_image_flatten = (
+                input_tokens_tensor == self.image_token_id).flatten()
+            image_index_tensor = is_image_flatten.nonzero().squeeze(-1)
         slot_mapping = make_cpu_tensor(slot_mapping,
                                        max_len=max_prompt_len,
                                        pad=_PAD_SLOT_ID,
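The new `_prepare_prompt` block only fires for the mm-optimized InternVL path: it records which flattened positions of the padded prompt tensor hold the image placeholder token. A toy example of the same indexing (token id and values made up for illustration):

```python
import torch

IMAGE_TOKEN_ID = 99  # placeholder id for illustration only
input_tokens_tensor = torch.tensor([
    [7, 99, 99, 8, 0],   # prompt 1: two image tokens, one pad
    [7, 6, 99, 8, 0],    # prompt 2: one image token, one pad
])
is_image_flatten = (input_tokens_tensor == IMAGE_TOKEN_ID).flatten()
image_index_tensor = is_image_flatten.nonzero().squeeze(-1)
print(image_index_tensor)  # tensor([1, 2, 7]) -- flat offsets into the padded batch
```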
@@ -1872,6 +1884,8 @@ def _prepare_prompt(
             input_positions=input_positions,
         )
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+        if image_index_tensor is not None:
+            multi_modal_kwargs['image_index'] = image_index_tensor
         multi_modal_kwargs = MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                         device=self.device)

@@ -3872,6 +3886,12 @@ def try_revert_dummy_output_tokens():
                          ('pixel_values') in model_input.multi_modal_kwargs))
             execute_model_kwargs['attn_metadata'] = attn_metadata

+            if 'image_index' in model_input.multi_modal_kwargs:
+                execute_model_kwargs[
+                    'image_index'] = model_input.multi_modal_kwargs[
+                        'image_index']
+                model_input.multi_modal_kwargs.pop('image_index', None)
+
             if not bypass_model_exec:
                 if self.model_is_mrope or self.is_mm_optimized:
                     if ('pixel_values') in execute_model_kwargs and \
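Taken together, the plumbing changes move `image_index` from `_prepare_prompt` through `multi_modal_kwargs` into `execute_model_kwargs`, where `compute_input_embeddings_for_mm_optimized` finally consumes and pops it. A plain-dict sketch of that hand-off (the real containers are `MultiModalKwargs` and the model input object; the values below are placeholders):

```python
# Sketch of the hand-off with plain dicts, not the actual runner objects.
multi_modal_kwargs = {"pixel_values": "...", "image_index": "..."}
execute_model_kwargs = {"input_ids": "...", "attn_metadata": "..."}

if "image_index" in multi_modal_kwargs:
    # Promote the runner-side index to a top-level kwarg so the embedding
    # helper receives it directly ...
    execute_model_kwargs["image_index"] = multi_modal_kwargs["image_index"]
    # ... and remove it from the multimodal kwargs so it is not forwarded to
    # the model's forward() as an unexpected keyword argument.
    multi_modal_kwargs.pop("image_index", None)
```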