@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
     return quantization_config
 
 
+def get_16a8w_qnn_qat_config(
+    act_observer=MovingAverageMinMaxObserver,
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-20}
+    act_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=True,
+        observer=act_observer.with_args(**extra_args),
+    )
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=act_fake_quant_ctr,
+    )
+    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_16a16w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
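For reference, the new helper can be smoke-tested in isolation. A minimal sketch, assuming the helper is importable from the ExecuTorch QNN quantizer module this diff appears to touch (the exact path is an assumption):

import torch

# Hypothetical import path, for illustration only.
from executorch.backends.qualcomm.quantizer.qconfig import get_16a8w_qnn_qat_config

config = get_16a8w_qnn_qat_config()
# Activations are stored as int32 to bypass the missing uint16 dtype,
# but the quantization range is uint16: [0, 65535].
assert config.input_activation.quant_min == 0
assert config.input_activation.quant_max == 65535
# Weights use the narrowed symmetric int8 range [-127, 127].
assert config.weight.quant_min == -127
assert config.weight.quant_max == 127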
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric=False,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -476,21 +536,38 @@ def get_qat_per_channel_quant_config(
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
 
     # torch does not support uint16 quantization, use int32 to bypass
-    act_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        reduce_range=True,
-        observer=act_observer,
-    )
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_fake_quant_ctr,
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply extra optimizations.
+        # Leaving quant_min/quant_max as None makes the observer default to 128.
+        # Passing explicit uint8 bounds would yield 127, which is undesired.
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
 
     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
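The zero-point behavior described in the symmetric branch above can be checked against the stock observer directly. A standalone sketch, assuming a recent PyTorch whose torch.ao.quantization observers accept plain integer dtypes:

import torch
from torch.ao.quantization import MovingAverageMinMaxObserver

# With per_tensor_symmetric and no explicit quant_min/quant_max, the observer
# pins unsigned 8-bit activations to zero_point = 128.
obs = MovingAverageMinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
)
obs(torch.randn(4, 8))  # feed one batch so min/max statistics exist
scale, zero_point = obs.calculate_qparams()
print(zero_point.item())  # expected: 128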
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )
 
-    bias_quantization_spec = _derived_bias_quant_spec
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
 
     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
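Callers opt into the HTP-friendly symmetric scheme through the new flag. A hypothetical call site (import path assumed, as above):

import torch
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_channel_quant_config,
)

# Symmetric uint8 activations: quant_min/quant_max stay None, so the observer
# falls back to zero_point = 128 as described in the inline comment.
qat_config = get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_symmetric=True,
)
assert qat_config.input_activation.qscheme == torch.per_tensor_symmetric
assert qat_config.input_activation.quant_min is None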