Commit fcc7f3b

Arm backend: Change QAT weight observer (#11787)
We change the QAT weight observer from FusedMovingAvgObsFakeQuantize to FakeQuantize with a true MinMaxObserver.

Signed-off-by: Elena Zhelezina <[email protected]>
1 parent 5006a14 commit fcc7f3b
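
For context (not part of the commit): FusedMovingAvgObsFakeQuantize records weight ranges through a moving-average observer, so extremes seen only early in training decay away, while FakeQuantize backed by a true MinMaxObserver keeps the running global min/max. A minimal sketch of that difference, using only torch.ao.quantization modules, with the quant bounds and eps mirroring the config in the diff below:

```python
import torch
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import MinMaxObserver

# New behaviour: fake-quant backed by a true running min/max observer.
fq_minmax = FakeQuantize(
    observer=MinMaxObserver,
    quant_min=-127,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    eps=2**-12,
)

# Old behaviour: fused fake-quant whose default observer is a moving average.
fq_movavg = FusedMovingAvgObsFakeQuantize(
    quant_min=-127,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)

torch.manual_seed(0)
wide = torch.randn(64) * 10  # one early batch with a wide dynamic range
fq_minmax(wide)
fq_movavg(wide)
for _ in range(100):  # many narrow-range batches afterwards
    narrow = torch.randn(64) * 0.1
    fq_minmax(narrow)
    fq_movavg(narrow)

# The true observer retains the early extremes; the moving average
# has decayed toward the recent, narrower range.
print(fq_minmax.activation_post_process.min_val,
      fq_minmax.activation_post_process.max_val)
print(fq_movavg.activation_post_process.min_val,
      fq_movavg.activation_post_process.max_val)
```

The quantization scale is derived from these recorded stats, so the two observers can yield noticeably different weight scales after the same training run.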

File tree

1 file changed: +12 −11 lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -38,7 +38,6 @@
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
     ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
@@ -95,24 +94,26 @@ def get_symmetric_quantization_config(
             **extra_args,
         ),
     )
+
+    # Setup quantization config for weights
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
     weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
+    # Determine the right observer/fake-quant constructor
     if is_qat:
-        # TODO: qat + per channel?
-        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
-    elif is_per_channel:
-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        # Set plain fake-quant with true min/max
+        weight_observer_or_fake_quant_ctr = FakeQuantize
+    else:
+        # PTQ: set min/max observer
+        weight_observer_or_fake_quant_ctr = (
+            PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
+        )
+
+    extra_args = {"eps": 2**-12}

-    extra_args: Dict[str, Any] = {"eps": 2**-12}
-    if is_qat:
-        if weight_qscheme == torch.per_tensor_symmetric:
-            extra_args["observer"] = MovingAverageMinMaxObserver
-        else:
-            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
         quant_min=weight_qmin,
```
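
A minimal usage sketch showing where this config is consumed, assuming the TOSAQuantizer API exported from this same file and the standard PT2E QAT entry points; `compile_spec` is a placeholder for a target-specific compile spec:

```python
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
example_inputs = (torch.randn(1, 16),)

quantizer = TOSAQuantizer(compile_spec)  # compile_spec: placeholder, target-specific
# is_qat=True now selects FakeQuantize with a true MinMaxObserver for weights
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_qat_pt2e(exported, quantizer)
# ... run the QAT fine-tuning loop on `prepared` here ...
quantized = convert_pt2e(prepared)
```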
