@@ -78,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
7878    optional_group .add_argument (
7979        "--quant-mode" ,
8080        type = str ,
81-         choices = ["int8" , "f8e4m3" , "f8e5m2" ],
81+         choices = ["int8" , "f8e4m3" , "f8e5m2" ,  "nf4_f8e4m3" ],
8282        default = None ,
8383        help = (
8484            "Quantization precision mode. This is used for applying full model quantization including activations. " 
@@ -307,7 +307,14 @@ def parse_args(parser: "ArgumentParser"):
307307    def  run (self ):
308308        from  ...exporters .openvino .__main__  import  infer_task , main_export , maybe_convert_tokenizers 
309309        from  ...exporters .openvino .utils  import  save_preprocessors 
310-         from  ...intel .openvino .configuration  import  _DEFAULT_4BIT_CONFIG , OVConfig , get_default_int4_config 
310+         from  ...intel .openvino .configuration  import  (
311+             _DEFAULT_4BIT_CONFIG ,
312+             OVCompressWeightsOptions ,
313+             OVConfig ,
314+             OVGeneralQuantizationConfig ,
315+             OVQuantizeOptions ,
316+             get_default_int4_config ,
317+         )
311318
312319        if  self .args .library  is  None :
313320            # TODO: add revision, subfolder and token to args 
@@ -342,43 +349,39 @@ def run(self):
342349            if  no_compression_parameter_provided (self .args ) and  self .args .weight_format  ==  "int4" :
343350                quantization_config  =  get_default_int4_config (self .args .model )
344351            else :
345-                 is_int8  =  self .args .weight_format  ==  "int8" 
346-                 quantization_config  =  {
347-                     "bits" : 8  if  is_int8  else  4 ,
348-                     "ratio" : 1  if  is_int8  else  (self .args .ratio  or  _DEFAULT_4BIT_CONFIG ["ratio" ]),
349-                     "sym" : self .args .sym  or  False ,
350-                     "group_size" : - 1  if  is_int8  else  self .args .group_size ,
351-                     "all_layers" : None  if  is_int8  else  self .args .all_layers ,
352-                     "dataset" : self .args .dataset ,
353-                     "num_samples" : self .args .num_samples ,
354-                     "quant_method" : "awq"  if  self .args .awq  else  "default" ,
355-                     "sensitivity_metric" : self .args .sensitivity_metric ,
356-                     "scale_estimation" : self .args .scale_estimation ,
357-                     "gptq" : self .args .gptq ,
358-                     "lora_correction" : self .args .lora_correction ,
359-                     "weight_format" : self .args .weight_format ,
360-                     "backup_precision" : self .args .backup_precision ,
361-                 }
352+                 quantization_config  =  prepare_for_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
362353
363354            if  quantization_config .get ("dataset" , None ) is  not   None :
364355                quantization_config ["trust_remote_code" ] =  self .args .trust_remote_code 
365356            ov_config  =  OVConfig (quantization_config = quantization_config )
366-         else :
357+         elif   self . args . quant_mode   is   not   None :
367358            if  self .args .dataset  is  None :
368359                raise  ValueError (
369360                    "Dataset is required for full quantization. Please provide it with --dataset argument." 
370361                )
371362
372-             quantization_config  =  {
373-                 "weight_format" : self .args .quant_mode ,
374-                 "activation_format" : self .args .quant_mode ,
375-                 "bits" : 8 ,
376-                 "sym" : self .args .sym  or  False ,
377-                 "dataset" : self .args .dataset ,
378-                 "num_samples" : self .args .num_samples ,
379-                 "smooth_quant_alpha" : self .args .smooth_quant_alpha ,
380-                 "trust_remote_code" : self .args .trust_remote_code ,
381-             }
363+             if  self .args .quant_mode  ==  "nf4_f8e4m3" :
364+                 wc_config  =  prepare_for_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
365+                 wc_config ["weight_format" ] =  "nf4" 
366+                 cw_options  =  OVCompressWeightsOptions .init_with_format (** wc_config )
367+ 
368+                 q_config  =  prepare_for_q_config (self .args )
369+                 q_config ["activation_format" ] =  "f8e4m3" 
370+                 q_options  =  OVQuantizeOptions .init_with_format (** q_config )
371+ 
372+                 quantization_config  =  OVGeneralQuantizationConfig .init_with_format (
373+                     bits = 8 ,
374+                     sym = self .args .sym ,
375+                     ignored_scope = None ,
376+                     num_samples = self .args .num_samples ,
377+                     dataset = self .args .dataset ,
378+                     trust_remote_code = self .args .trust_remote_code ,
379+                     weight_format = self .args .weight_format ,
380+                 )
381+                 quantization_config .compress_weights_options  =  cw_options 
382+                 quantization_config .quantize_options  =  q_options 
383+             else :
384+                 quantization_config  =  prepare_for_q_config (self .args )
382385            ov_config  =  OVConfig (quantization_config = quantization_config )
383386
384387        quantization_config  =  ov_config .quantization_config  if  ov_config  else  None 
@@ -470,3 +473,36 @@ def run(self):
470473                library_name = library_name ,
471474                # **input_shapes, 
472475            )
476+ 
477+ 
def prepare_for_wc_config(args, default_configs):
    """Translate parsed CLI arguments into a weight-compression config dict.

    Args:
        args: the parsed ``argparse`` namespace; attributes read are
            ``weight_format``, ``ratio``, ``sym``, ``group_size``,
            ``all_layers``, ``dataset``, ``num_samples``, ``awq``,
            ``sensitivity_metric``, ``scale_estimation``, ``gptq``,
            ``lora_correction`` and ``backup_precision``.
        default_configs: mapping providing fallback values (only ``"ratio"``
            is consulted, when the user did not supply one in 4-bit mode).

    Returns:
        A plain ``dict`` of keyword arguments for the weight-compression
        quantization config.
    """
    int8_selected = args.weight_format == "int8"

    # int8 pins bits/ratio/group_size/all_layers; 4-bit modes take them
    # from the CLI (with the default ratio as fallback).
    if int8_selected:
        bits, ratio, group_size, all_layers = 8, 1, -1, None
    else:
        bits = 4
        ratio = args.ratio or default_configs["ratio"]
        group_size = args.group_size
        all_layers = args.all_layers

    return {
        "bits": bits,
        "ratio": ratio,
        "sym": args.sym or False,
        "group_size": group_size,
        "all_layers": all_layers,
        "dataset": args.dataset,
        "num_samples": args.num_samples,
        "quant_method": "awq" if args.awq else "default",
        "sensitivity_metric": args.sensitivity_metric,
        "scale_estimation": args.scale_estimation,
        "gptq": args.gptq,
        "lora_correction": args.lora_correction,
        "weight_format": args.weight_format,
        "backup_precision": args.backup_precision,
    }
496+ 
497+ 
def prepare_for_q_config(args):
    """Translate parsed CLI arguments into a full-quantization config dict.

    Both weights and activations use the precision selected via
    ``--quant-mode``; the bit-width is fixed at 8.

    Args:
        args: the parsed ``argparse`` namespace; attributes read are
            ``quant_mode``, ``sym``, ``dataset``, ``num_samples``,
            ``smooth_quant_alpha`` and ``trust_remote_code``.

    Returns:
        A plain ``dict`` of keyword arguments for the full-quantization
        config.
    """
    precision = args.quant_mode
    return dict(
        weight_format=precision,
        activation_format=precision,
        bits=8,
        sym=args.sym or False,
        dataset=args.dataset,
        num_samples=args.num_samples,
        smooth_quant_alpha=args.smooth_quant_alpha,
        trust_remote_code=args.trust_remote_code,
    )
0 commit comments