@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
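Every hunk in this commit makes the same one-line swap: `FakeQuantize` becomes `FusedMovingAvgObsFakeQuantize`, which folds the `MovingAverageMinMaxObserver` statistics update and the fake-quantize step into a single fused op, speeding up QAT forward passes. A minimal standalone sketch of the constructor pattern from this hunk (the instantiation and forward lines are illustrative, not part of the patch; the explicit `observer=` argument is the module's documented default):

```python
import torch
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)

# Same arguments as the hunk above: the full uint16 range (0..65535)
# carried in an int32 dtype.
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,
    quant_max=torch.iinfo(torch.uint16).max,
    observer=MovingAverageMinMaxObserver,
)

# with_args() returns a factory; calling it builds the fused module, which
# updates the running min/max and fake-quantizes in one op per forward pass.
fq = act_fake_quant_ctr()
y = fq(torch.randn(4, 8))
```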
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
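The comment block in this hunk can be verified directly against the observer: with symmetric uint8 quantization and no explicit bounds, PyTorch's observer pins `zero_point` at 128 (the HTP-friendly value), whereas user-specified uint8 bounds shift it to the rounded-down midpoint `(0 + 255) // 2 == 127`. A quick illustrative check (not part of this change):

```python
import torch
from torch.ao.quantization import MovingAverageMinMaxObserver

x = torch.randn(64)

# No explicit quant_min/quant_max: symmetric uint8 defaults zero_point to 128.
obs_default = MovingAverageMinMaxObserver(
    dtype=torch.uint8, qscheme=torch.per_tensor_symmetric
)
obs_default(x)
_, zp = obs_default.calculate_qparams()
print(zp.item())  # 128

# Explicit uint8 bounds: zero_point becomes (0 + 255) // 2 == 127, which the
# comment above flags as undesired for HTP.
obs_custom = MovingAverageMinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(x)
_, zp = obs_custom.calculate_qparams()
print(zp.item())  # 127
```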
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
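As the context lines in the last hunk show, each constructor is ultimately handed to a `QuantizationSpec` through its `observer_or_fake_quant_ctr` field, which is how the prepare step later instantiates the fused module. A hedged sketch of that wiring for the affine (`else`) branch, assuming upstream PyTorch's `torch.ao.quantization.quantizer.QuantizationSpec` and simplifying away the surrounding config plumbing in this file:

```python
import torch
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec

act_dtype = torch.uint8  # illustrative choice; uint16 would use the int32 carrier

act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
    quant_min=torch.iinfo(act_dtype).min,
    quant_max=torch.iinfo(act_dtype).max,
    qscheme=torch.per_tensor_affine,
    observer=MovingAverageMinMaxObserver,
)

# The spec carries the logical dtype and range; the ctr supplies the QAT module.
act_quantization_spec = QuantizationSpec(
    dtype=act_dtype,
    quant_min=torch.iinfo(act_dtype).min,
    quant_max=torch.iinfo(act_dtype).max,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=act_fake_quant_ctr,
)
```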