@@ -365,7 +365,7 @@ def main(args):
365
365
f"\n --Quantize-Script-- algo={ args .algo } , dataset={ args .dataset } , calib_size={ args .calib_size } , "
366
366
f"batch_size={ args .batch_size } , block_size={ args .block_size } , add-position-ids={ args .add_position_ids } , "
367
367
f"past-kv={ args .add_past_kv_inputs } , rcalib={ args .use_random_calib } , device={ args .device } , "
368
- f"use_zero_point={ args .use_zero_point } , use_fp32={ args .use_fp32 } \n "
368
+ f"use_zero_point={ args .use_zero_point } , use_fp32={ args .use_fp32 } k_quant_mixed= { args . k_quant_mixed } \n "
369
369
)
370
370
371
371
print (
@@ -435,6 +435,8 @@ def main(args):
435
435
awqclip_alpha_step = args .awqclip_alpha_step ,
436
436
awqclip_alpha_min = args .awqclip_alpha_min ,
437
437
awqclip_bsz_col = args .awqclip_bsz_col ,
438
+ k_quant_mixed = args .k_quant_mixed ,
439
+ int8_layers = args .int8_layers ,
438
440
)
439
441
logging .info (f"\n Quantization process took { time .time () - t } seconds" )
440
442
@@ -594,6 +596,20 @@ def main(args):
594
596
default = False ,
595
597
action = "store_true" ,
596
598
)
597
-
599
+ parser .add_argument (
600
+ "--k_quant_mixed" ,
601
+ default = False ,
602
+ action = "store_true" ,
603
+ help = "True when we want to use k_quant_mixed quantization" ,
604
+ )
605
+ parser .add_argument (
606
+ "--int8_layers" ,
607
+ type = str ,
608
+ default = "" ,
609
+ help = (
610
+ "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
611
+ "Example: 'layers.0,layers.1,lm_head'"
612
+ ),
613
+ )
598
614
args = parser .parse_args ()
599
615
main (args )
0 commit comments