@@ -315,6 +315,31 @@ def quantize(
315315 else :
316316 raise TypeError (f"Unsupported model type: { type (self .model )} " )
317317
def _check_model_state(self, sub_model_names: List[str] = None):
    """
    Ensure the model to be optimized has not already been compressed by NNCF.

    The check reads the runtime info (`get_rt_info()`) of each inspected model and looks for
    an "nncf" entry holding either a "weight_compression" or a "quantization" config.

    Args:
        sub_model_names (`List[str]`, *optional*):
            Names of attributes of `self.model` to inspect. When `None`, only `self.model.model`
            is inspected. Names that are missing on `self.model`, or whose value is `None`
            (an absent optional sub-model), are skipped.

    Raises:
        RuntimeError: If any inspected model carries an NNCF weight compression or
            quantization config in its runtime info.
    """
    message_template = (
        "Couldn't apply optimization to the model because it was already compressed with config: {}. "
        "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
        "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
    )

    def check_rt_info(ov_model):
        rt_info = ov_model.get_rt_info()
        if "nncf" not in rt_info:
            return
        nncf_info = rt_info["nncf"]
        # Weight compression takes precedence over quantization when reporting the config.
        for config_key in ("weight_compression", "quantization"):
            applied_config = nncf_info.get(config_key, None)
            if applied_config is not None:
                raise RuntimeError(message_template.format(applied_config))

    if sub_model_names is None:
        check_rt_info(self.model.model)
        return

    for name in sub_model_names:
        sub_model = getattr(self.model, name, None)
        if sub_model is None:
            # Optional sub-models may be present as attributes but set to None;
            # there is nothing to check for them.
            continue
        # A sub-model is usually a wrapper exposing the underlying model via `.model`.
        # NOTE(review): some callers appear to pass names of attributes that already are the
        # underlying model (e.g. "*_model"); in that case inspect the object itself.
        check_rt_info(getattr(sub_model, "model", sub_model))
342+
318343 def _quantize_ovbasemodel (
319344 self ,
320345 ov_config : OVConfig ,
@@ -325,7 +350,7 @@ def _quantize_ovbasemodel(
325350 remove_unused_columns : bool = True ,
326351 ** kwargs ,
327352 ):
328- from optimum .intel .openvino .modeling_seq2seq import _OVModelForWhisper
353+ from optimum .intel .openvino .modeling_seq2seq import _OVModelForWhisper , OVModelForSeq2SeqLM
329354 from optimum .intel .openvino .modeling_visual_language import OVModelForVisualCausalLM
330355
331356 if is_diffusers_available ():
@@ -404,6 +429,7 @@ def _quantize_ovbasemodel(
404429 "text_encoder_2" ,
405430 "text_encoder_3" ,
406431 ]
432+ self ._check_model_state (sub_model_names )
407433 sub_models = filter (lambda x : x , (getattr (self .model , name ) for name in sub_model_names ))
408434 for sub_model in sub_models :
409435 _weight_only_quantization (sub_model .model , quantization_config_copy , ** kwargs )
@@ -421,6 +447,7 @@ def _quantize_ovbasemodel(
421447 self .model .clear_requests ()
422448 else :
423449 # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
450+ self ._check_model_state ()
424451 self .model .model = _hybrid_quantization (
425452 self .model .model , quantization_config , calibration_dataset , ** kwargs
426453 )
@@ -436,19 +463,31 @@ def _quantize_ovbasemodel(
436463 "transformer" ,
437464 "text_encoder_3" ,
438465 ]
466+ self ._check_model_state (sub_model_names )
439467 sub_models = filter (lambda x : x , (getattr (self .model , name ) for name in sub_model_names ))
440468 for sub_model in sub_models :
441469 _weight_only_quantization (sub_model .model , quantization_config , ** kwargs )
442470 self .model .clear_requests ()
443471 elif isinstance (self .model , OVModelForVisualCausalLM ):
444472 language_model = self .model .language_model
445- _weight_only_quantization (language_model .model , quantization_config , calibration_dataset , ** kwargs )
446473 sub_model_names = ["vision_embeddings" , "text_embeddings" ] + self .model .additional_parts
474+ self ._check_model_state (sub_model_names + ["language_model" ])
475+ _weight_only_quantization (language_model .model , quantization_config , calibration_dataset , ** kwargs )
447476 sub_models = [getattr (self .model , f"{ name } _model" ) for name in sub_model_names ]
448477 for sub_model in sub_models :
449478 _weight_only_quantization (sub_model , OVWeightQuantizationConfig (bits = 8 , sym = True ), ** kwargs )
450479 self .model .clear_requests ()
480+ elif isinstance (self .model , OVModelForSeq2SeqLM ):
481+ sub_model_names = ["encoder" , "decoder" ]
482+ if self .model .decoder_with_past is not None :
483+ sub_model_names .append ("decoder_with_past" )
484+ self ._check_model_state (sub_model_names )
485+ sub_models = [getattr (self .model , name ) for name in sub_model_names ]
486+ for sub_model in sub_models :
487+ _weight_only_quantization (sub_model , quantization_config , ** kwargs )
488+ self .model .clear_requests ()
451489 else :
490+ self ._check_model_state ()
452491 _weight_only_quantization (self .model .model , quantization_config , calibration_dataset , ** kwargs )
453492 self .model .request = None
454493 else :
@@ -460,6 +499,7 @@ def _quantize_ovbasemodel(
460499
461500 # Quantize model(s)
462501 if isinstance (self .model , _OVModelForWhisper ):
502+ self ._check_model_state (["encoder_model" , "decoder_model" , "decoder_with_past_model" ])
463503 self ._quantize_whisper_model (quantization_config , calibration_dataset , ** kwargs )
464504 else :
465505 quantized_model = _full_quantization (
0 commit comments