1414from transformers .models .sam .modeling_sam import SamImageSegmentationOutput , SamPositionalEmbedding
1515
1616from ...exporters .openvino .utils import save_config
17+ from .. import OVConfig
18+ from .configuration import OVQuantizationConfigBase
1719from .modeling_base import OVBaseModel , OVModelPart
1820from .utils import (
1921 ONNX_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME ,
@@ -83,19 +85,23 @@ def __init__(
8385 dynamic_shapes : bool = True ,
8486 ov_config : Optional [Dict [str , str ]] = None ,
8587 model_save_dir : Optional [Union [str , Path , TemporaryDirectory ]] = None ,
88+ quantization_config : Union [OVQuantizationConfigBase , Dict ] = None ,
8689 ** kwargs ,
8790 ):
8891 self .config = config
8992 self ._model_save_dir = model_save_dir
9093 self ._device = device .upper ()
9194 self .ov_config = {} if ov_config is None else {** ov_config }
9295 self .preprocessors = kwargs .get ("preprocessors" , [])
93- self .vision_encoder_model = vision_encoder_model
94- self .prompt_encoder_mask_decoder_model = prompt_encoder_mask_decoder_model
9596 self ._compile_only = kwargs .get ("compile_only" , False )
97+
98+ self ._openvino_config = None
99+ if quantization_config :
100+ self ._openvino_config = OVConfig (quantization_config = quantization_config )
101+
96102 enable_compilation = kwargs .get ("compile" , True )
97- self .vision_encoder = OVSamVisionEncoder (self . vision_encoder_model , self )
98- self .prompt_encoder_mask_decoder = OVSamPromptEncoder (self . prompt_encoder_mask_decoder_model , self )
103+ self .vision_encoder = OVSamVisionEncoder (vision_encoder_model , self )
104+ self .prompt_encoder_mask_decoder = OVSamPromptEncoder (prompt_encoder_mask_decoder_model , self )
99105
100106 if dynamic_shapes and not self .is_dynamic and not self ._compile_only :
101107 self .reshape ()
@@ -117,9 +123,8 @@ def clear_requests(self):
117123 raise ValueError (
118124 "`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
119125 )
120-
121- for _ , component in self .components .items ():
122- component .clear_requests ()
126+ self .vision_encoder .clear_requests ()
127+ self .prompt_encoder_mask_decoder .clear_requests ()
123128
124129 def compile (self ):
125130 self .vision_encoder ._compile ()
@@ -143,15 +148,16 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
143148 """
144149 src_models = self .ov_submodels
145150 dst_file_names = {
146- "vision_encoder_model " : OV_VISION_ENCODER_MODEL_NAME ,
147- "prompt_encoder_mask_decoder_model " : OV_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME ,
151+ "vision_encoder " : OV_VISION_ENCODER_MODEL_NAME ,
152+ "prompt_encoder_mask_decoder " : OV_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME ,
148153 }
149154
150155 for name in self ._ov_submodel_names :
151156 model = src_models [name ]
152157 dst_file_name = dst_file_names [name ]
153158 dst_path = os .path .join (save_directory , dst_file_name )
154159 ov .save_model (model , dst_path , compress_to_fp16 = False )
160+ self ._save_openvino_config (save_directory )
155161
156162 @classmethod
157163 def _from_pretrained (
@@ -166,6 +172,8 @@ def _from_pretrained(
166172 vision_encoder_file_name : Optional [str ] = None ,
167173 prompt_encoder_mask_decoder_file_name : Optional [str ] = None ,
168174 local_files_only : bool = False ,
175+ load_in_8bit : bool = False ,
176+ quantization_config : Union [OVQuantizationConfigBase , Dict ] = None ,
169177 ** kwargs ,
170178 ):
171179 """
@@ -198,6 +206,10 @@ def _from_pretrained(
198206 openvino_prompt_encoder_mask_decoder.xml, allowing to load the decoder model with a different name.
199207 local_files_only(`bool`, *optional*, defaults to `False`):
200208 Whether or not to only look at local files (i.e., do not try to download the model).
209+ load_in_8bit(`bool`, *optional*, defaults to `False`):
210+ Whether or not to apply 8-bit weight quantization.
211+ quantization_config(`Union[OVQuantizationConfigBase, Dict]`, *optional*, defaults to `None`):
212+ Quantization configuration to apply to the model.
201213 """
202214 if use_auth_token is not None :
203215 warnings .warn (
@@ -272,21 +284,50 @@ def _from_pretrained(
272284 model_save_dir ,
273285 )
274286
287+ quantization_config = cls ._prepare_quantization_config (quantization_config , load_in_8bit )
275288 model = cls (
276289 vision_encoder_model = vision_encoder_model ,
277290 prompt_encoder_mask_decoder_model = prompt_encoder_model ,
278291 config = config ,
279292 model_save_dir = model_save_dir ,
293+ quantization_config = quantization_config ,
280294 ** kwargs ,
281295 )
282296
297+ if quantization_config is not None :
298+ from optimum .intel import OVQuantizer
299+
300+ quantizer = OVQuantizer (model )
301+ quantization_config_copy = quantization_config .clone ()
302+ quantization_config_copy .tokenizer = quantization_config .tokenizer or model_id
303+ quantization_config_copy .processor = quantization_config .processor or model_id
304+ quantizer .quantize (ov_config = OVConfig (quantization_config = quantization_config_copy ))
305+
283306 return model
284307
285308 @property
286309 def _ov_submodel_names (self ):
287- model_names = ["vision_encoder_model " , "prompt_encoder_mask_decoder_model " ]
310+ model_names = ["vision_encoder " , "prompt_encoder_mask_decoder " ]
288311 return model_names
289312
313+ @property
314+ def ov_submodels (self ) -> Dict [str , ov .Model ]:
315+ return {component_name : getattr (self , component_name ).model for component_name in self ._ov_submodel_names }
316+
317+ @property
318+ def vision_encoder_model (self ) -> ov .Model :
319+ logger .warning (
320+ "Access to the `vision_encoder_model` attribute is deprecated and will be removed in optimum-intel v1.26, please use `vision_encoder.model` instead"
321+ )
322+ return self .vision_encoder .model
323+
324+ @property
325+ def prompt_encoder_mask_decoder_model (self ) -> ov .Model :
326+ logger .warning (
327+ "Access to the `prompt_encoder_mask_decoder_model` attribute is deprecated and will be removed in optimum-intel v1.26, please use `prompt_encoder_mask_decoder.model` instead"
328+ )
329+ return self .prompt_encoder_mask_decoder .model
330+
290331 def reshape (self , batch_size : int = - 1 , point_batch_size : int = - 1 , num_points_per_image : int = - 1 ):
291332 """
292333 Propagates the given input shapes on the model's layers, fixing the inputs shapes of the model.
@@ -304,19 +345,19 @@ def reshape(self, batch_size: int = -1, point_batch_size: int = -1, num_points_p
304345 "`reshape()` is not supported with `compile_only` mode, please initialize model without this option"
305346 )
306347 vision_encoder_shapes = {}
307- for inputs in self .vision_encoder_model .inputs :
348+ for inputs in self .vision_encoder . model .inputs :
308349 vision_encoder_shapes [inputs ] = inputs .get_partial_shape ()
309350 vision_encoder_shapes [inputs ][0 ] = batch_size
310- self .vision_encoder_model .reshape (vision_encoder_shapes )
351+ self .vision_encoder . model .reshape (vision_encoder_shapes )
311352 self .vision_encoder .request = None
312353 mask_decoder_shapes = {}
313- for inputs in self .prompt_encoder_mask_decoder_model .inputs :
354+ for inputs in self .prompt_encoder_mask_decoder . model .inputs :
314355 mask_decoder_shapes [inputs ] = inputs .get_partial_shape ()
315356 mask_decoder_shapes [inputs ][0 ] = batch_size
316357 if inputs .get_any_name () in ["input_points" , "input_labels" ]:
317358 mask_decoder_shapes [inputs ][1 ] = point_batch_size
318359 mask_decoder_shapes [inputs ][2 ] = num_points_per_image
319- self .prompt_encoder_mask_decoder_model .reshape (mask_decoder_shapes )
360+ self .prompt_encoder_mask_decoder . model .reshape (mask_decoder_shapes )
320361 self .prompt_encoder_mask_decoder .request = None
321362 return self
322363
@@ -398,6 +439,6 @@ def get_image_features(self, pixel_values, *args, **kwargs):
398439
399440 @property
400441 def is_dynamic (self ):
401- return model_has_dynamic_inputs (self .vision_encoder_model ) or model_has_dynamic_inputs (
402- self .prompt_encoder_mask_decoder_model
442+ return model_has_dynamic_inputs (self .vision_encoder . model ) or model_has_dynamic_inputs (
443+ self .prompt_encoder_mask_decoder . model
403444 )
0 commit comments