@@ -78,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
7878 optional_group .add_argument (
7979 "--quant-mode" ,
8080 type = str ,
81- choices = ["int8" , "f8e4m3" , "f8e5m2" ],
81+ choices = ["int8" , "f8e4m3" , "f8e5m2" , "nf4_f8e4m3" , "nf4_f8e5m2" , "int4_f8e4m3" , "int4_f8e5m2" ],
8282 default = None ,
8383 help = (
8484 "Quantization precision mode. This is used for applying full model quantization including activations. "
@@ -352,23 +352,7 @@ def run(self):
352352 if no_compression_parameter_provided (self .args ) and self .args .weight_format == "int4" :
353353 quantization_config = get_default_int4_config (self .args .model )
354354 else :
355- is_int8 = self .args .weight_format == "int8"
356- quantization_config = {
357- "bits" : 8 if is_int8 else 4 ,
358- "ratio" : 1.0 if is_int8 else (self .args .ratio or _DEFAULT_4BIT_CONFIG ["ratio" ]),
359- "sym" : self .args .sym or False ,
360- "group_size" : - 1 if is_int8 else self .args .group_size ,
361- "all_layers" : None if is_int8 else self .args .all_layers ,
362- "dataset" : self .args .dataset ,
363- "num_samples" : self .args .num_samples ,
364- "quant_method" : "awq" if self .args .awq else "default" ,
365- "sensitivity_metric" : self .args .sensitivity_metric ,
366- "scale_estimation" : self .args .scale_estimation ,
367- "gptq" : self .args .gptq ,
368- "lora_correction" : self .args .lora_correction ,
369- "weight_format" : self .args .weight_format ,
370- "backup_precision" : self .args .backup_precision ,
371- }
355+ quantization_config = prepare_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
372356
373357 if quantization_config .get ("dataset" , None ) is not None :
374358 quantization_config ["trust_remote_code" ] = self .args .trust_remote_code
@@ -378,16 +362,24 @@ def run(self):
378362 raise ValueError (
379363 "Dataset is required for full quantization. Please provide it with --dataset argument."
380364 )
381- quantization_config = {
382- "weight_format" : self .args .quant_mode ,
383- "activation_format" : self .args .quant_mode ,
384- "bits" : 8 ,
385- "sym" : self .args .sym or False ,
386- "dataset" : self .args .dataset ,
387- "num_samples" : self .args .num_samples ,
388- "smooth_quant_alpha" : self .args .smooth_quant_alpha ,
389- "trust_remote_code" : self .args .trust_remote_code ,
390- }
365+
366+ if self .args .quant_mode in ["nf4_f8e4m3" , "nf4_f8e5m2" , "int4_f8e4m3" , "int4_f8e5m2" ]:
367+ wc_config = prepare_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
368+ wc_dtype , q_dtype = self .args .quant_mode .split ("_" )
369+ wc_config ["dtype" ] = wc_dtype
370+
371+ q_config = prepare_q_config (self .args )
372+ q_config ["dtype" ] = q_dtype
373+
374+ quantization_config = {
375+ "weight_quantization_config" : wc_config ,
376+ "full_quantization_config" : q_config ,
377+ "num_samples" : self .args .num_samples ,
378+ "dataset" : self .args .dataset ,
379+ "trust_remote_code" : self .args .trust_remote_code ,
380+ }
381+ else :
382+ quantization_config = prepare_q_config (self .args )
391383 ov_config = OVConfig (quantization_config = quantization_config )
392384
393385 quantization_config = ov_config .quantization_config if ov_config else None
@@ -486,3 +478,35 @@ def run(self):
486478 variant = self .args .variant ,
487479 # **input_shapes,
488480 )
481+
482+
def prepare_wc_config(args, default_configs):
    """Build the weight-compression configuration dict from parsed CLI arguments.

    Args:
        args: Namespace-like object. The attributes read are ``weight_format``,
            ``ratio``, ``sym``, ``group_size``, ``all_layers``, ``dataset``,
            ``num_samples``, ``awq``, ``sensitivity_metric``,
            ``scale_estimation``, ``gptq``, ``lora_correction`` and
            ``backup_precision``.
        default_configs: Mapping whose ``"ratio"`` entry is used as a fallback
            when ``args.ratio`` is ``None``. It is only consulted when
            ``args.weight_format`` is not ``"int8"``.

    Returns:
        A new dict holding the weight-compression settings.
    """
    is_int8 = args.weight_format == "int8"

    if is_int8:
        ratio = 1.0
    elif args.ratio is not None:
        # Compare against None rather than relying on truthiness so that an
        # explicitly supplied ratio of 0.0 is kept instead of being replaced
        # by the default.
        ratio = args.ratio
    else:
        ratio = default_configs["ratio"]

    return {
        "bits": 8 if is_int8 else 4,
        "ratio": ratio,
        # An unset (None) flag is normalised to False.
        "sym": args.sym or False,
        # For int8 the user-supplied group size / all_layers values are ignored.
        "group_size": -1 if is_int8 else args.group_size,
        "all_layers": None if is_int8 else args.all_layers,
        "dataset": args.dataset,
        "num_samples": args.num_samples,
        "quant_method": "awq" if args.awq else "default",
        "sensitivity_metric": args.sensitivity_metric,
        "scale_estimation": args.scale_estimation,
        "gptq": args.gptq,
        "lora_correction": args.lora_correction,
        "dtype": args.weight_format,
        "backup_precision": args.backup_precision,
    }
501+
502+
def prepare_q_config(args):
    """Assemble the full-quantization configuration dict from parsed CLI arguments.

    Args:
        args: Namespace-like object. The attributes read are ``quant_mode``,
            ``sym``, ``dataset``, ``num_samples``, ``smooth_quant_alpha`` and
            ``trust_remote_code``.

    Returns:
        A new dict with the keys ``dtype``, ``bits`` (always 8), ``sym``,
        ``dataset``, ``num_samples``, ``smooth_quant_alpha`` and
        ``trust_remote_code``.
    """
    q_config = {"dtype": args.quant_mode, "bits": 8}

    # A falsy flag (e.g. None when the option was not given) becomes False.
    symmetric = args.sym
    q_config["sym"] = symmetric if symmetric else False

    # The remaining options are copied from the namespace unchanged.
    for option in ("dataset", "num_samples", "smooth_quant_alpha", "trust_remote_code"):
        q_config[option] = getattr(args, option)

    return q_config
0 commit comments