1616from transformers import (
1717 AutoConfig ,
1818 AutoImageProcessor ,
19+ AutoModelForCausalLM ,
1920 GenerationConfig ,
2021 GenerationMixin ,
2122 PretrainedConfig ,
3031from .configuration import OVConfig , OVWeightQuantizationConfig
3132from .modeling_base import OVBaseModel , OVModelPart
3233from .modeling_decoder import CausalLMOutputWithPast , OVModelForCausalLM
33- from .utils import TemporaryDirectory
34+ from .utils import (
35+ OV_LANGUAGE_MODEL_NAME ,
36+ OV_TEXT_EMBEDDINGS_MODEL_NAME ,
37+ OV_VISION_EMBEDDINGS_MODEL_NAME ,
38+ TemporaryDirectory ,
39+ )
40+
41+
42+ try :
43+ from transformers import LlavaForConditionalGeneration
44+ except ImportError :
45+ LlavaForConditionalGeneration = None
46+
47+ try :
48+ from transformers import LlavaNextForConditionalGeneration
49+ except ImportError :
50+ LlavaNextForConditionalGeneration = None
3451
3552
3653logger = logging .getLogger (__name__ )
@@ -67,13 +84,19 @@ def __init__(
6784 def compile (self ):
6885 if self .request is None :
6986 logger .info (f"Compiling the Language model to { self ._device } ..." )
70- self . request = core . compile_model ( self . model , self . _device , self . ov_config ). create_infer_request ()
87+ super (). compile ()
7188 self ._compile_text_emb ()
7289
7390 def _compile_text_emb (self ):
7491 if self .text_emb_request is None :
7592 logger .info (f"Compiling the Text embeddings model to { self ._device } ..." )
76- self .text_emb_request = core .compile_model (self .text_emb_model , self ._device , self .ov_config )
93+ if self ._compile_only :
94+ self .text_emb_request = self .text_emb_model
95+ else :
96+ logger .info (f"Compiling the Text embeddings model to { self ._device } ..." )
97+ self .text_emb_request = self ._compile_model (
98+ self .text_emb_model , self ._device , self .ov_config , self .model_save_dir
99+ )
77100
78101 def clear_requests (self ):
79102 if self ._compile_only :
@@ -238,12 +261,18 @@ def forward(self, img_features):
238261 return self .request (img_features )[0 ]
239262
240263
241- MODEL_PARTS_CLS_MAPPING = {"resampler" : OVResampler , "vision_projection" : OVVisionProjection }
264+ MODEL_PARTS_CLS_MAPPING = {
265+ "resampler" : OVResampler ,
266+ "language_model" : OVModelWithEmbedForCausalLM ,
267+ "vision_embeddings" : OVVisionEmbedding ,
268+ "vision_projection" : OVVisionProjection ,
269+ }
242270
243271
244272class OVModelForVisualCausalLM (OVBaseModel , GenerationMixin ):
245273 export_feature = "image-text-to-text"
246274 additional_parts = []
275+ auto_model_class = AutoModelForCausalLM
247276
248277 def __init__ (
249278 self ,
@@ -285,11 +314,11 @@ def __init__(
285314 self .lm_model ,
286315 self .text_embeddings_model ,
287316 config = config ,
288- deivce = device ,
317+ device = device ,
289318 ov_config = ov_config ,
290319 model_save_dir = model_save_dir ,
291320 quantization_config = quantization_config ,
292- compile = not self ._compile_only and enable_compilation ,
321+ compile = self ._compile_only or enable_compilation ,
293322 compile_only = self ._compile_only ,
294323 )
295324 self .vision_embeddings = OVVisionEmbedding (self .vision_embeddings_model , self )
@@ -315,19 +344,15 @@ def clear_requests(self):
315344 "`clear_requests()` is not supported with `compile_only` mode, please intialize model without this option"
316345 )
317346
318- self .language_model .clear_requests ()
319- components = [self .vision_embeddings ] + [getattr (self , part ) for part in self .additional_parts ]
320- for component in components :
321- if component is not None :
322- component .request = None
347+ for _ , component in self .components .items ():
348+ component .clear_requests ()
323349
324350 def compile (self ):
325- self .language_model .compile ()
326- self .vision_embeddings ._compile ()
327- for part in self .additional_parts :
328- part_model = getattr (self , part , None )
329- if part_model is not None :
330- part_model ._compile ()
351+ for _ , component in self .components .items ():
352+ if isinstance (component , OVModelPart ):
353+ component ._compile ()
354+ else :
355+ component .compile ()
331356
332357 def _save_config (self , save_directory ):
333358 """
@@ -345,21 +370,21 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
345370 save_directory (`str` or `Path`):
346371 The directory where to save the model files.
347372 """
348- src_files = [self .lm_model , self .text_embeddings_model , self .vision_embeddings_model ]
349- dst_file_names = [
350- "openvino_language_model.xml" ,
351- "openvino_text_embeddings_model.xml" ,
352- "openvino_vision_embeddings_model.xml" ,
353- ]
354- for part in self .additional_parts :
355- model = getattr (self , f"{ part } _model" , None )
356- if model is not None :
357- src_files .append (model )
358- dst_file_names .append (f"openvino_{ part } _model.xml" )
373+ src_models = self .submodels
374+ dst_file_names = {
375+ "lm_model" : OV_LANGUAGE_MODEL_NAME ,
376+ "text_embeddings_model" : OV_TEXT_EMBEDDINGS_MODEL_NAME ,
377+ "vision_embeddings_model" : OV_VISION_EMBEDDINGS_MODEL_NAME ,
378+ }
379+ for name in self ._submodel_names :
380+ if name not in dst_file_names :
381+ dst_file_names [name ] = f"openvino_{ name } .xml"
359382
360- for src_file , dst_file_name in zip (src_files , dst_file_names ):
383+ for name in self ._submodel_names :
384+ model = src_models [name ]
385+ dst_file_name = dst_file_names [name ]
361386 dst_path = os .path .join (save_directory , dst_file_name )
362- ov .save_model (src_file , dst_path , compress_to_fp16 = False )
387+ ov .save_model (model , dst_path , compress_to_fp16 = False )
363388
364389 self ._save_openvino_config (save_directory )
365390 if self .generation_config is not None :
@@ -429,14 +454,18 @@ def _from_pretrained(
429454 token = use_auth_token
430455
431456 model_file_names = {
432- "language_model" : "openvino_language_model.xml" ,
433- "text_embeddings" : "openvino_text_embeddings_model.xml" ,
434- "vision_embeddings" : "openvino_vision_embeddings_model.xml" ,
457+ "language_model" : OV_LANGUAGE_MODEL_NAME ,
458+ "language_model_bin" : OV_LANGUAGE_MODEL_NAME .replace (".xml" , ".bin" ),
459+ "text_embeddings" : OV_TEXT_EMBEDDINGS_MODEL_NAME ,
460+ "text_embeddings_bin" : OV_TEXT_EMBEDDINGS_MODEL_NAME .replace (".xml" , ".bin" ),
461+ "vision_embeddings" : OV_VISION_EMBEDDINGS_MODEL_NAME ,
462+ "vision_embeddings_bin" : OV_VISION_EMBEDDINGS_MODEL_NAME .replace (".xml" , ".bin" ),
435463 }
436464
437465 model_cls = MODEL_TYPE_TO_CLS_MAPPING [config .model_type ]
438466 for part in model_cls .additional_parts :
439467 model_file_names [part ] = f"openvino_{ part } _model.xml"
468+ model_file_names [part + "_bin" ] = f"openvino_{ part } _model.bin"
440469 compile_only = kwargs .get ("compile_only" , False )
441470 if os .path .isdir (model_id ):
442471 # Load model from a local directory
@@ -593,6 +622,28 @@ def _from_transformers(
593622 ** kwargs ,
594623 )
595624
625+ @property
626+ def _component_names (self ):
627+ base_components = ["language_model" , "vision_embeddings" ]
628+ additional_components = [part for part in self .additional_parts if getattr (self , part , None ) is not None ]
629+ return base_components + additional_components
630+
631+ @property
632+ def components (self ):
633+ return {component_name : getattr (self , component_name ) for component_name in self ._component_names }
634+
635+ @property
636+ def _submodel_names (self ):
637+ model_names = ["lm_model" , "text_embeddings_model" , "vision_embeddings_model" ]
638+ for part in self .additional_parts :
639+ if getattr (self , part , None ) is not None :
640+ model_names .append (part + "_model" )
641+ return model_names
642+
643+ @property
644+ def submodels (self ):
645+ return {submodel_name : getattr (self , submodel_name ) for submodel_name in self ._submodel_names }
646+
596647 def reshape (self , batch_size : int , sequence_length : int ):
597648 logger .warning ("Static shapes are not supported for causal language model." )
598649 return self
@@ -601,17 +652,14 @@ def half(self):
601652 """
602653 Converts all the model weights to FP16 for more efficient inference on GPU.
603654 """
604- apply_moc_transformations (self .lm_model , cf = False )
605- compress_model_transformation (self .lm_model )
606- apply_moc_transformations (self .text_embeddings_model , cf = False )
607- compress_model_transformation (self .text_embeddings_model )
608- apply_moc_transformations (self .vision_embeddings_model , cf = False )
609- compress_model_transformation (self .vision_embeddings_model )
610- for part in self .additional_parts :
611- model = getattr (self , f"{ part } _model" , None )
612- if model is not None :
613- apply_moc_transformations (model , cf = False )
614- compress_model_transformation (model )
655+ for _ , submodel in self .submodels .items ():
656+ apply_moc_transformations (submodel , cf = False )
657+ compress_model_transformation (submodel )
658+ return self
659+
660+ def to (self , device ):
661+ self .language_model .to (device )
662+ super ().to (device )
615663 return self
616664
617665 def forward (
@@ -625,11 +673,8 @@ def forward(
625673 position_ids = None ,
626674 image_bound = None ,
627675 tgt_sizes = None ,
628- images = None ,
629676 ** kwargs ,
630677 ):
631- if pixel_values is None and images is not None :
632- pixel_values = images
633678 inputs_embeds , attention_mask , position_ids = self .get_multimodal_embeddings (
634679 input_ids ,
635680 pixel_values ,
@@ -733,7 +778,6 @@ def prepare_inputs_for_generation(
733778 "image_sizes" : image_sizes ,
734779 "image_bound" : kwargs .get ("image_bound" ),
735780 "tgt_sizes" : kwargs .get ("tgt_sizes" ),
736- "images" : kwargs .get ("images" ),
737781 }
738782 )
739783 return model_inputs
@@ -756,6 +800,8 @@ def preprocess_inputs(
756800
757801
758802class _OVLlavaForCausalLM (OVModelForVisualCausalLM ):
803+ auto_model_class = LlavaForConditionalGeneration
804+
759805 def __init__ (
760806 self ,
761807 language_model : ov .Model ,
@@ -941,6 +987,8 @@ def preprocess_inputs(
941987
942988
943989class _OVLlavaNextForCausalLM (_OVLlavaForCausalLM ):
990+ auto_model_class = LlavaNextForConditionalGeneration
991+
944992 # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
945993 def pack_image_features (self , image_features , image_sizes , image_newline = None ):
946994 from transformers .models .llava_next .modeling_llava_next import get_anyres_image_grid_shape , unpad_image
@@ -1211,7 +1259,7 @@ def get_text_embeddings(self, input_ids, **kwargs):
12111259 return super ().get_text_embeddings (for_inputs_embeds_ids , ** kwargs )
12121260
12131261
1214- class _OvInternVLForCausalLM (OVModelForVisualCausalLM ):
1262+ class _OVInternVLForCausalLM (OVModelForVisualCausalLM ):
12151263 def get_vision_embeddings (self , pixel_values , input_ids = None , ** kwargs ):
12161264 if input_ids is not None and input_ids .shape [1 ] == 1 :
12171265 return None
@@ -1822,7 +1870,7 @@ def preprocess_inputs(
18221870 attention_mask = torch .ones_like (input_ids , dtype = torch .int64 )
18231871 result = {"input_ids" : input_ids , "attention_mask" : attention_mask }
18241872 if image is not None :
1825- result ["images " ] = torch . unsqueeze ( processor (images = image , return_tensors = "pt" )["pixel_values" ][ 0 ], 0 )
1873+ result ["pixel_values " ] = processor (images = [ image ] , return_tensors = "pt" )["pixel_values" ]
18261874 return result
18271875
18281876
@@ -1979,8 +2027,8 @@ def preprocess_inputs(
19792027MODEL_TYPE_TO_CLS_MAPPING = {
19802028 "llava" : _OVLlavaForCausalLM ,
19812029 "llava_next" : _OVLlavaNextForCausalLM ,
1982- "internvl_chat" : _OvInternVLForCausalLM ,
19832030 "minicpmv" : _OVMiniCPMVForCausalLM ,
19842031 "llava-qwen2" : _OVNanoLlavaForCausalLM ,
19852032 "phi3_v" : _OVPhi3VisionForCausalLM ,
2033+ "internvl_chat" : _OVInternVLForCausalLM ,
19862034}
0 commit comments