@@ -697,6 +697,33 @@ def can_generate(self):
 
 
 class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
+    def __init__(
+        self,
+        language_model: ov.Model,
+        text_embeddings: ov.Model,
+        vision_embeddings: ov.Model,
+        config: PretrainedConfig = None,
+        device: str = "CPU",
+        dynamic_shapes: bool = True,
+        ov_config: Optional[Dict[str, str]] = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            language_model=language_model,
+            text_embeddings=text_embeddings,
+            vision_embeddings=vision_embeddings,
+            config=config,
+            device=device,
+            dynamic_shapes=dynamic_shapes,
+            ov_config=ov_config,
+            model_save_dir=model_save_dir,
+            quantization_config=quantization_config,
+            **kwargs,
+        )
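+        # Configs produced by processors that pre-expand <image> tokens expose
+        # `image_seq_length`; its presence signals that the non-legacy merge path is available.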
+        self._support_new_processing = hasattr(self.config, "image_seq_length")
+
     def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
         if input_ids is not None and input_ids.shape[1] == 1:
             return None
@@ -725,17 +752,11 @@ def merge_vision_text_embeddings(
         input_ids,
         attention_mask,
         position_ids=None,
-        legacy_processing=None,
+        legacy_processing=False,
         **kwargs,
     ):
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
-        if legacy_processing is None:
-            legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
-                or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
-            )
 
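+        # legacy_processing is decided by the caller (get_multimodal_embeddings) and passed in explicitly.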
         if legacy_processing:
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -768,15 +789,6 @@ def merge_vision_text_embeddings(
             final_attention_mask = torch.zeros(
                 batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
             )
-            # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-            # set the corresponding tensors into their correct target device.
-            target_device = inputs_embeds.device
-            batch_indices, non_image_indices, text_to_overwrite = (
-                batch_indices.to(target_device),
-                non_image_indices.to(target_device),
-                text_to_overwrite.to(target_device),
-            )
-            attention_mask = attention_mask.to(target_device)
 
             # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
             # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
@@ -787,15 +799,15 @@ def merge_vision_text_embeddings(
                 (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
             )
             image_to_overwrite[batch_indices, text_to_overwrite] = False
-            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
 
             if image_to_overwrite.sum() != image_features.shape[:-1].numel():
                 raise ValueError(
                     f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                     f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
                 )
 
-            final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+            final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
             final_attention_mask |= image_to_overwrite
             position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
@@ -815,11 +827,12 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
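+        # Heuristic: if the prompt holds fewer <image> placeholder tokens than
+        # image_seq_length, the processor did not pre-expand them, so fall back
+        # to the legacy merge path.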
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            )
+        else:
+            legacy_processing = True
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
@@ -830,38 +843,19 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids
 
     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
-        if not self.language_model.stateful:
-            first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
-        else:
-            first_layer_past_key_value = torch.from_numpy(
-                self.language_model.request.query_state()[0].state.data[:, :, :, 0]
-            )
-
-        # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-        batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
         # Get the target length
         target_length = input_ids.shape[1]
-        past_length = first_layer_past_key_value.shape[-1]
+        past_length = self.language_model._get_past_length(past_key_values)
 
         extended_attention_mask = torch.ones(
             (attention_mask.shape[0], past_length),
             dtype=attention_mask.dtype,
             device=attention_mask.device,
         )
 
-        # Filter out only the tokens that can be un-attended, this can happen
-        # if one uses Llava + Fused modules where the cache on the
-        # first iteration is already big enough, or if one passes custom cache
-        valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-        new_batch_index = batch_index[valid_indices]
-        new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-        # Zero-out the places where we don't need to attend
-        extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
         attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
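+        # Cumulative positions over the attended tokens; masked positions are pinned
+        # to 1, matching the convention used in merge_vision_text_embeddings above.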
+        position_ids = torch.cumsum(attention_mask, axis=1) - 1
+        position_ids[attention_mask == 0] = 1
         return attention_mask, position_ids
 
 
@@ -938,11 +932,13 @@ def get_multimodal_embeddings(
 
         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
 
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            )
+        else:
+            legacy_processing = True
+
         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
@@ -996,7 +992,7 @@ def merge_vision_text_embeddings(
         input_ids,
         attention_mask,
         position_ids=None,
-        legacy_processing=None,
+        legacy_processing=False,
         **kwargs,
     ):
         image_token_index = self.config.image_token_index