@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
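
For context, here is a minimal, self-contained sketch of the fused fake-quantize pattern every hunk above switches to. It assumes a recent PyTorch where torch.ao.quantization exports FusedMovingAvgObsFakeQuantize and MovingAverageMinMaxObserver; the fq name and the sample tensor are illustrative and not part of the change:

import torch
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)

# Same constructor shape as the 16-bit activation configs above: a fused
# observe-and-fake-quantize module spanning the uint16 range, carried as int32.
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,
    quant_max=torch.iinfo(torch.uint16).max,
    observer=MovingAverageMinMaxObserver,
)

# During QAT the fused module updates its moving-average min/max statistics
# and applies fake quantization in a single fused op, whereas plain
# FakeQuantize runs the observer and the quantize/dequantize step separately.
fq = act_fake_quant_ctr()
y = fq(torch.randn(2, 8))  # illustrative activation tensor

The fused variant is generally chosen for QAT training speed; its numerics are intended to match the FakeQuantize + MovingAverageMinMaxObserver pair it replaces here.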