+import copy
 import logging
 import os
 import warnings
+from abc import abstractmethod
 from pathlib import Path
 from typing import Dict, Optional, Tuple, Union
 
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
-from transformers import AutoConfig, GenerationConfig, GenerationMixin, PretrainedConfig
+from PIL.Image import Image
+from transformers import (
+    AutoConfig,
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedTokenizer,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 
 from ...exporters.openvino import main_export
 from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
+from .. import OVQuantizer
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel, OVModelPart
 from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -181,6 +191,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"
 
     def forward(self, pixel_values, **kwargs):
+        self._compile()
         inputs = {self._main_input: pixel_values}
         if len(self.input_names) > 1:
             for name in self.input_names:
@@ -210,6 +221,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
 
     def forward(self, image_feature, pos_embed, key_padding_mask):
+        self._compile()
         result = self.request(
             {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
         )[0]
@@ -244,7 +256,7 @@ def __init__(
         self.ov_config = {} if ov_config is None else {**ov_config}
         self.preprocessors = kwargs.get("preprocessors", [])
         self.lm_model = language_model
-        self.text_embdings_model = text_embeddings
+        self.text_embeddings_model = text_embeddings
         self.vision_embeddings_model = vision_embeddings
         self._supports_cache_class = False
         self.main_input_name = "input_ids"
@@ -261,13 +273,13 @@ def __init__(
         self._set_ov_config_parameters()
         self.language_model = OVModelWithEmbedForCausalLM(
             self.lm_model,
-            self.text_embdings_model,
+            self.text_embeddings_model,
             config=config,
             device=device,
             ov_config=ov_config,
             model_save_dir=model_save_dir,
             quantization_config=quantization_config,
-            compile=not self._compile_only,
+            compile=not self._compile_only and enable_compilation,
             compile_only=self._compile_only,
         )
         self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
@@ -287,6 +299,18 @@ def __init__(
         except AttributeError:
             pass
 
+    def clear_requests(self):
+        if self._compile_only:
+            raise ValueError(
305+ "`clear_requests()` is not supported with `compile_only` mode, please intialize model without this option"
+            )
+
+        self.language_model.clear_requests()
+        components = [self.vision_embeddings] + [getattr(self, part) for part in self.additional_parts]
+        for component in components:
+            if component is not None:
+                component.request = None
+
     def compile(self):
         self.language_model.compile()
         self.vision_embeddings._compile()
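The new `clear_requests()` helper resets the compiled inference requests of the language model, the vision embeddings, and any additional parts. A minimal usage sketch, assuming the class is exported as `optimum.intel.OVModelForVisualCausalLM` and using an illustrative model id:

```python
# Sketch only; clear_requests() raises in compile_only mode, as implemented above.
from optimum.intel import OVModelForVisualCausalLM

model = OVModelForVisualCausalLM.from_pretrained("llava-hf/llava-1.5-7b-hf", export=True)
model.to("GPU")          # ask for a different OpenVINO device for all submodels
model.clear_requests()   # drop the requests compiled for the previous device
model.compile()          # recompile the language model and vision embeddings
```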
@@ -304,11 +328,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             save_directory (`str` or `Path`):
                 The directory where to save the model files.
         """
-        src_files = [self.lm_model, self.text_embdings_model, self.vision_embeddings_model]
+        src_files = [self.lm_model, self.text_embeddings_model, self.vision_embeddings_model]
         dst_file_names = [
             "openvino_language_model.xml",
             "openvino_text_embeddings_model.xml",
-            "openvino_vision_embeddings.xml",
+            "openvino_vision_embeddings_model.xml",
         ]
         for part in self.additional_parts:
             model = getattr(self, f"{part}_model", None)
@@ -387,26 +411,18 @@ def _from_pretrained(
                 raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
             token = use_auth_token
 
-        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
-
-        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-        compile_only = kwargs.get("compile_only", False)
-
-        # Load model from a local directory
-        if os.path.isdir(model_id):
-            model_save_dir = Path(model_id)
         model_file_names = {
             "language_model": "openvino_language_model.xml",
             "text_embeddings": "openvino_text_embeddings_model.xml",
             "vision_embeddings": "openvino_vision_embeddings_model.xml",
         }
 
+        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
         for part in model_cls.additional_parts:
             model_file_names[part] = f"openvino_{part}_model.xml"
-        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
-        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
         compile_only = kwargs.get("compile_only", False)
         if os.path.isdir(model_id):
+            # Load model from a local directory
             model_save_dir = Path(model_id)
             file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
         else:
@@ -424,11 +440,11 @@ def _from_pretrained(
                 file_names[name] = model_cache_path
             model_save_dir = Path(model_cache_path).parent
         if not compile_only:
-            language_model = model_cls.load_model(file_names["language_model"], quantization_config)
-            text_embeddings = model_cls.load_model(file_names["text_embeddings"], quantization_config)
-            vision_embeddings = model_cls.load_model(file_names["vision_embeddings"], quantization_config)
+            language_model = model_cls.load_model(file_names["language_model"])
+            text_embeddings = model_cls.load_model(file_names["text_embeddings"])
+            vision_embeddings = model_cls.load_model(file_names["vision_embeddings"])
             for part in model_cls.additional_parts:
-                kwargs[part] = model_cls.load_model(file_names[part], quantization_config)
+                kwargs[part] = model_cls.load_model(file_names[part])
         else:
             language_model = model_cls._compile_model(
                 file_names["language_model"],
@@ -468,7 +484,12 @@ def _from_pretrained(
         except Exception:
             pass
 
-        return model_cls(
+        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+        to_quantize = not compile_only and quantization_config is not None
+        if to_quantize:
+            kwargs["compile"] = False
+
+        model = model_cls(
             language_model=language_model,
             text_embeddings=text_embeddings,
             vision_embeddings=vision_embeddings,
@@ -478,6 +499,15 @@ def _from_pretrained(
             **kwargs,
         )
 
+        if to_quantize:
+            quantization_config_copy = copy.deepcopy(quantization_config)
+            quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
+            potential_processor_id = config.mm_vision_tower if isinstance(model, _OVNanoLlavaForCausalLM) else model_id
+            quantization_config_copy.processor = quantization_config.processor or potential_processor_id
+            OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
+
+        return model
+
     @classmethod
     def _from_transformers(
         cls,
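With this change, a weight quantization config passed to `from_pretrained` is no longer applied per submodel at load time; the model is first built with compilation disabled and then handed to `OVQuantizer`, with the tokenizer and processor defaulting to `model_id` (or `config.mm_vision_tower` for the nanollava variant). A hedged sketch of the resulting user-facing flow; the model id, bit width, and output directory below are illustrative assumptions, not part of this diff:

```python
# Sketch only: loading with a weight quantization config now routes through OVQuantizer.
from optimum.intel import OVModelForVisualCausalLM, OVWeightQuantizationConfig

model = OVModelForVisualCausalLM.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=4),
)
model.save_pretrained("llava-ov-int4")
```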
@@ -556,8 +586,8 @@ def half(self):
556586 """
557587 apply_moc_transformations (self .lm_model , cf = False )
558588 compress_model_transformation (self .lm_model )
559- apply_moc_transformations (self .text_embdings_model , cf = False )
560- compress_model_transformation (self .text_embdings_model )
589+ apply_moc_transformations (self .text_embeddings_model , cf = False )
590+ compress_model_transformation (self .text_embeddings_model )
561591 apply_moc_transformations (self .vision_embeddings_model , cf = False )
562592 compress_model_transformation (self .vision_embeddings_model )
563593 for part in self .additional_parts :
@@ -695,6 +725,18 @@ def can_generate(self):
695725 """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
696726 return True
697727
728+ @staticmethod
729+ @abstractmethod
730+ def preprocess_inputs (
731+ processor ,
732+ text : str ,
733+ image : Optional [Image ] = None ,
734+ tokenizer : Optional [PreTrainedTokenizer ] = None ,
735+ ):
736+ """
737+ Preprocess input instruction and an image.
738+ """
739+
698740
699741class _OVLlavaForCausalLM (OVModelForVisualCausalLM ):
700742 def __init__ (
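Each concrete subclass below implements this static hook so callers can build `generate()`-ready inputs without knowing the model-specific prompt format. A usage sketch for the LLaVA variant, assuming an illustrative model id, image file, and a processor that ships a chat template:

```python
# Sketch under stated assumptions; the model id and image path are hypothetical.
from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

inputs = model.preprocess_inputs(processor, "What is in the picture?", image=Image.open("cat.png"))
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```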
@@ -858,6 +900,20 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
         position_ids[attention_mask == 0] = 1
         return attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if image is None:
+            raise ValueError("Image is required.")
+        chat_template = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image"}]}]
+        prompt = processor.apply_chat_template(chat_template, add_generation_prompt=True)
+        inputs = processor(images=image, text=prompt, return_tensors="pt")
+        return inputs
+
 
 class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
     # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
@@ -1372,6 +1428,19 @@ def merge_vision_text_embeddings(
         )
         return vllm_embedding, attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if image is None:
+            raise ValueError("Image is required.")
+        prompt = f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n"
+        inputs = processor([prompt], [image], return_tensors="pt")
+        return inputs
+
 
 class _OVNanoLlavaForCausalLM(OVModelForVisualCausalLM):
     def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
@@ -1544,6 +1613,25 @@ def get_multimodal_embeddings(
 
         return new_input_embeds, attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required.")
+        messages = [{"role": "user", "content": f"<image>\n{text}"}]
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
+        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
+        result = {"input_ids": input_ids, "attention_mask": attention_mask}
+        if image is not None:
+            result["images"] = torch.unsqueeze(processor(images=image, return_tensors="pt")["pixel_values"][0], 0)
+        return result
+
 
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,