@@ -287,12 +287,44 @@ def _get_llama_prompt_embeds(
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+
         _, _, image_height, image_width = image_embeds.shape
         patch_size = self.text_encoder.config.vision_config.patch_size
         num_image_tokens = (image_height // patch_size) * (image_width // patch_size)
         if self.text_encoder.config.vision_config.vision_feature_select_strategy == "default":
             num_image_tokens -= 1
 
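+        # Each image placeholder token in `text_input_ids` is expanded into
+        # `prompt_template["image_emb_len"]` slots so the projected image embeddings
+        # can later occupy those positions alongside the text tokens.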
+        image_token_index = self.text_encoder.config.image_token_index
+        pad_token_id = self.text_encoder.config.pad_token_id
+        batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+        special_image_token_mask = text_input_ids == image_token_index
+        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+
+        max_expanded_length = max_sequence_length + (
+            num_special_image_tokens.max() * (prompt_template["image_emb_len"] - 1)
+        )
+        new_token_positions = (
+            torch.cumsum((special_image_token_mask * (prompt_template["image_emb_len"] - 1) + 1), -1) - 1
+        )
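+        # The cumulative sum of per-token widths (1 for text tokens, `image_emb_len`
+        # for image placeholders) gives each original token's index in the expanded sequence.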
+        nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1]
+        if left_padding:
+            new_token_positions += nb_image_pad[:, None]
+
+        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
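+        # Scatter the original text token ids into a pad-initialized tensor of the
+        # expanded length; the slots reserved for the image are then marked with
+        # `image_token_index`, and the attention mask is widened to match.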
+        expanded_input_ids = torch.full(
+            (batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype, device=device
+        )
+        expanded_attention_mask = torch.ones(
+            (batch_size, max_expanded_length), dtype=prompt_attention_mask.dtype, device=device
+        )
+
+        expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+        expanded_input_ids[batch_indices, prompt_template["image_emb_start"] : prompt_template["image_emb_end"]] = (
+            image_token_index
+        )
+
         inputs = self.llava_processor(
             text=prompt,
             images=image,