@@ -241,8 +241,7 @@ def get_ptq_per_channel_quant_config(
         torch.int8,
         torch.int16,
     }
-    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
-    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
+    supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
     assert (
         act_dtype in supported_act_types
     ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -276,9 +275,11 @@ def get_ptq_per_channel_quant_config(
     )

     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
@@ -310,9 +311,11 @@ def get_ptq_per_block_quant_config(
         act_symmetric=act_symmetric,
     )
     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
@@ -463,8 +466,7 @@ def get_qat_per_channel_quant_config(
         torch.int8,
         torch.int16,
     }
-    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
-    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
+    supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
     assert (
         act_dtype in supported_act_types
     ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -491,17 +493,21 @@ def get_qat_per_channel_quant_config(
     )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer=MovingAveragePerChannelMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
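Side note, not part of the patch: all of these hunks converge on the same (dtype, quant_min, quant_max) rule, so here is a minimal standalone sketch of it, assuming a PyTorch version where `torch.int4` exists as a dtype (the premise of this change). The helper name `weight_quant_range` is hypothetical, introduced only for illustration.

```python
import torch


def weight_quant_range(weight_dtype: torch.dtype) -> tuple[torch.dtype, int, int]:
    """Sketch of the range logic repeated across this patch.

    int4 weights ride in an int8 container with a hard-coded symmetric
    [-7, 7] range; wider integer dtypes use iinfo-derived bounds, with
    quant_min bumped to iinfo.min + 1 so the range stays symmetric
    around zero (matching qscheme=torch.per_channel_symmetric).
    """
    if weight_dtype == torch.int4:
        # The patch hard-codes the int4 bounds rather than querying iinfo.
        return torch.int8, -7, 7
    info = torch.iinfo(weight_dtype)
    return weight_dtype, info.min + 1, info.max


# Assumed behavior:
#   weight_quant_range(torch.int8)  -> (torch.int8,  -127,   127)
#   weight_quant_range(torch.int16) -> (torch.int16, -32767, 32767)
#   weight_quant_range(torch.int4)  -> (torch.int8,  -7,     7)
```

A possible follow-up: factoring the rule into one helper like this would remove the ternaries triplicated across the PTQ per-channel, PTQ per-block, and QAT configs.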