@@ -315,31 +315,6 @@ def quantize(
         else:
             raise TypeError(f"Unsupported model type: {type(self.model)}")
 
-    def _check_model_state(self, sub_model_names: List[str] = None):
-        message_template = (
-            "Couldn't apply optimization to the model because it was already compressed with config: {}. "
-            "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
-            "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
-        )
-
-        def check_rt_info(ov_model):
-            rt_info = ov_model.get_rt_info()
-            if "nncf" in rt_info:
-                model_weight_compression_config = rt_info["nncf"].get("weight_compression", None)
-                model_quantization_config = rt_info["nncf"].get("quantization", None)
-                if model_weight_compression_config is not None:
-                    raise RuntimeError(message_template.format(model_weight_compression_config))
-                elif model_quantization_config is not None:
-                    raise RuntimeError(message_template.format(model_quantization_config))
-
-        if sub_model_names is None:
-            check_rt_info(self.model.model)
-        else:
-            for name in sub_model_names:
-                if hasattr(self.model, name):
-                    ov_model = getattr(self.model, name).model
-                    check_rt_info(ov_model)
-
     def _quantize_ovbasemodel(
         self,
         ov_config: OVConfig,
@@ -350,7 +325,7 @@ def _quantize_ovbasemodel(
         remove_unused_columns: bool = True,
         **kwargs,
     ):
-        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper, OVModelForSeq2SeqLM
+        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
         from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM
 
         if is_diffusers_available():
@@ -429,7 +404,6 @@ def _quantize_ovbasemodel(
                         "text_encoder_2",
                         "text_encoder_3",
                     ]
-                    self._check_model_state(sub_model_names)
                     sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs)
@@ -447,7 +421,6 @@ def _quantize_ovbasemodel(
                     self.model.clear_requests()
                 else:
                     # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
-                    self._check_model_state()
                     self.model.model = _hybrid_quantization(
                         self.model.model, quantization_config, calibration_dataset, **kwargs
                     )
@@ -463,31 +436,19 @@ def _quantize_ovbasemodel(
                         "transformer",
                         "text_encoder_3",
                     ]
-                    self._check_model_state(sub_model_names)
                     sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model.model, quantization_config, **kwargs)
                     self.model.clear_requests()
                 elif isinstance(self.model, OVModelForVisualCausalLM):
                     language_model = self.model.language_model
-                    sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
-                    self._check_model_state(sub_model_names + ["language_model"])
                     _weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
+                    sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
                     sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names]
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs)
                     self.model.clear_requests()
-                elif isinstance(self.model, OVModelForSeq2SeqLM):
-                    sub_model_names = ["encoder", "decoder"]
-                    if self.model.decoder_with_past is not None:
-                        sub_model_names.append("decoder_with_past")
-                    self._check_model_state(sub_model_names)
-                    sub_models = [getattr(self.model, name) for name in sub_model_names]
-                    for sub_model in sub_models:
-                        _weight_only_quantization(sub_model, quantization_config, **kwargs)
-                    self.model.clear_requests()
                 else:
-                    self._check_model_state()
                     _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs)
                     self.model.request = None
         else:
@@ -499,7 +460,6 @@ def _quantize_ovbasemodel(
 
         # Quantize model(s)
         if isinstance(self.model, _OVModelForWhisper):
-            self._check_model_state(["encoder_model", "decoder_model", "decoder_with_past_model"])
             self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
         else:
             quantized_model = _full_quantization(
@@ -1050,6 +1010,7 @@ def _weight_only_quantization(
     calibration_dataset: Optional[Union[nncf.Dataset, Iterable]] = None,
     **kwargs,
 ) -> openvino.runtime.Model:
+    _verify_not_optimized(model)
     config = quantization_config
     if isinstance(config, dict):
         config = OVWeightQuantizationConfig.from_dict(quantization_config)
@@ -1106,6 +1067,7 @@ def _full_quantization(
     calibration_dataset: nncf.Dataset,
     **kwargs,
 ):
+    _verify_not_optimized(model)
     advanced_parameters_kwargs = {}
     if quantization_config.smooth_quant_alpha is not None:
         advanced_parameters_kwargs["smooth_quant_alphas"] = AdvancedSmoothQuantParameters(
@@ -1227,3 +1189,20 @@ def _hybrid_quantization(
         **kwargs,
     )
     return quantized_model
+
+
+def _verify_not_optimized(ov_model):
+    message_template = (
+        "Cannot apply optimization to the model because it was already optimized with the following config: {}. "
+        "To avoid this issue, set load_in_8bit=False or do not pass quantization_config at export in .from_pretrained(), "
+        "or explicitly specify the weight format with --weight_format fp16/fp32 when using the CLI."
+    )
+
+    rt_info = ov_model.get_rt_info()
+    if "nncf" in rt_info:
+        model_weight_compression_config = rt_info["nncf"].get("weight_compression", None)
+        model_quantization_config = rt_info["nncf"].get("quantization", None)
+        if model_weight_compression_config is not None:
+            raise RuntimeError(message_template.format(model_weight_compression_config))
+        elif model_quantization_config is not None:
+            raise RuntimeError(message_template.format(model_quantization_config))
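Taken together, the diff moves the already-optimized guard from per-model-type call sites in OVQuantizer into the two shared entry points, _weight_only_quantization and _full_quantization, so every quantization path is checked once and new call sites are covered automatically. Below is a minimal sketch of the behavior this enforces, assuming the public optimum-intel API shown in the diff; the model ID and save directory are placeholders, not values from this commit:

# Sketch: attempting to re-quantize an already-compressed model.
# "HF_MODEL_ID" is a placeholder; any exportable causal LM works.
from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer, OVWeightQuantizationConfig

# load_in_8bit=True compresses weights at export and records an
# "nncf" -> "weight_compression" entry in the OpenVINO model's rt_info.
model = OVModelForCausalLM.from_pretrained("HF_MODEL_ID", export=True, load_in_8bit=True)

quantizer = OVQuantizer.from_pretrained(model)
ov_config = OVConfig(quantization_config=OVWeightQuantizationConfig(bits=4))
try:
    quantizer.quantize(ov_config=ov_config, save_directory="quantized_model")
except RuntimeError as err:
    # With this change, _verify_not_optimized fires inside _weight_only_quantization:
    # "Cannot apply optimization to the model because it was already optimized ..."
    print(err)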