Skip to content

Commit 31e13b0

Browse files
Arm backend: Add support for QAT+per-channel combo (#13511)
Fixes a bug where an unsupported observer was used for QAT combined with per-channel quantization. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 1164fe4 commit 31e13b0

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -101,18 +101,20 @@ def get_symmetric_quantization_config(
101101
weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
102102
MinMaxObserver
103103
)
104+
104105
# Determine the right observer/fake-quant constructor
105106
if is_qat:
106-
# Set plain fake-quant with true min/max
107-
weight_observer_or_fake_quant_ctr = FakeQuantize
107+
if is_per_channel:
108+
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
109+
else:
110+
# Set plain fake-quant with true min/max
111+
weight_observer_or_fake_quant_ctr = FakeQuantize
108112
else:
109113
# PTQ: set min/max observer
110114
weight_observer_or_fake_quant_ctr = (
111115
PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
112116
)
113117

114-
extra_args = {"eps": 2**-12}
115-
116118
weight_quantization_spec = QuantizationSpec(
117119
dtype=torch.int8,
118120
quant_min=weight_qmin,

backends/arm/test/misc/test_bn_relu_folding_qat.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,17 @@ def forward(self, x: torch.Tensor):
4040

4141

4242
models = {
43-
"conv_bn_relu": ConvModule(batch_norm=True),
44-
"conv_relu": ConvModule(batch_norm=False),
43+
# name : (model, is_per_channel)
44+
"conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True),
45+
"conv_relu_per_channel": (ConvModule(batch_norm=False), True),
46+
"conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False),
47+
"conv_relu_per_tensor": (ConvModule(batch_norm=False), False),
4548
}
4649

4750

48-
@common.parametrize("model", models)
49-
def test_qat_tosa_INT(model: torch.nn.Module):
51+
@common.parametrize("test_data", models)
52+
def test_qat_tosa_INT(test_data):
53+
model, per_channel = test_data
5054
pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1)
5155
tosa_version = conftest.get_option("tosa_version")
5256
tosa_profiles = {
@@ -59,7 +63,7 @@ def test_qat_tosa_INT(model: torch.nn.Module):
5963
Quantize(
6064
quantizer=quantizer,
6165
quantization_config=get_symmetric_quantization_config(
62-
is_qat=True, is_per_channel=False
66+
is_qat=True, is_per_channel=per_channel
6367
),
6468
is_qat=True,
6569
),

0 commit comments

Comments (0)