Commit e652746

Use FusedMovingAvgObsFakeQuantize instead of FakeQuantize for faster QAT
Differential Revision: D83583655
Pull Request resolved: #14740
1 parent 822a711 commit e652746
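
For context, FusedMovingAvgObsFakeQuantize is the fused counterpart of FakeQuantize in torch.ao.quantization: instead of running a separate MovingAverageMinMaxObserver module followed by a fake-quantize step on every forward pass, it updates the moving-average min/max and fake-quantizes in a single fused op, which is what makes QAT iterations faster. A minimal standalone sketch of the swap follows; the argument values are illustrative, not copied from qconfig.py.

import torch
from torch.ao.quantization import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)

# Before: observer update and fake quantization run as two separate module calls.
slow_ctr = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.uint8,
    qscheme=torch.per_tensor_affine,
)

# After: one fused op handles the moving-average min/max update and the fake quantization.
# The swap is drop-in here because these configs already observe with
# MovingAverageMinMaxObserver, the observer FusedMovingAvgObsFakeQuantize expects.
fast_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.uint8,
    qscheme=torch.per_tensor_affine,
)

fq = fast_ctr()              # instantiate like any observer/fake-quant constructor
y = fq(torch.randn(4, 16))   # observes min/max and fake-quantizes in one fused kernel call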

1 file changed: +5 -5 lines changed

backends/qualcomm/quantizer/qconfig.py

Lines changed: 5 additions & 5 deletions
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
            quant_min=torch.iinfo(act_dtype).min,
            quant_max=torch.iinfo(act_dtype).max,
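
To show where these constructors land downstream, here is a hedged sketch of how one of the updated activation fake-quant constructors feeds a PT2E QuantizationSpec, mirroring the 16a8w hunk above; the surrounding QuantizationConfig wiring and the exact observer arguments in qconfig.py are elided, so treat the values as illustrative.

import torch
from torch.ao.quantization import FusedMovingAvgObsFakeQuantize, MovingAverageMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# 16-bit activation range carried in an int32 container dtype, as in get_16a8w_qnn_qat_config.
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,
    quant_max=torch.iinfo(torch.uint16).max,
    qscheme=torch.per_tensor_affine,
    observer=MovingAverageMinMaxObserver.with_args(eps=2**-20),
)

# The constructor is handed to the activation spec; the quantizer instantiates one
# fused fake-quant module per observed tensor when the model is prepared for QAT.
act_quantization_spec = QuantizationSpec(
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,
    quant_max=torch.iinfo(torch.uint16).max,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=act_fake_quant_ctr,
)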
