diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 4518feeb403..9fa15568cc4 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -101,18 +101,20 @@ def get_symmetric_quantization_config(
     weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
 
+    # Determine the right observer/fake-quant constructor
     if is_qat:
-        # Set plain fake-quant with true min/max
-        weight_observer_or_fake_quant_ctr = FakeQuantize
+        if is_per_channel:
+            weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        else:
+            # Set plain fake-quant with true min/max
+            weight_observer_or_fake_quant_ctr = FakeQuantize
     else:
         # PTQ: set min/max observer
         weight_observer_or_fake_quant_ctr = (
             PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
         )
 
-    extra_args = {"eps": 2**-12}
-
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
         quant_min=weight_qmin,
diff --git a/backends/arm/test/misc/test_bn_relu_folding_qat.py b/backends/arm/test/misc/test_bn_relu_folding_qat.py
index c39c1694d0a..c88c38e869d 100644
--- a/backends/arm/test/misc/test_bn_relu_folding_qat.py
+++ b/backends/arm/test/misc/test_bn_relu_folding_qat.py
@@ -40,13 +40,17 @@ def forward(self, x: torch.Tensor):
 
 
 models = {
-    "conv_bn_relu": ConvModule(batch_norm=True),
-    "conv_relu": ConvModule(batch_norm=False),
+    # name : (model, is_per_channel)
+    "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True),
+    "conv_relu_per_channel": (ConvModule(batch_norm=False), True),
+    "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False),
+    "conv_relu_per_tensor": (ConvModule(batch_norm=False), False),
 }
 
 
-@common.parametrize("model", models)
-def test_qat_tosa_INT(model: torch.nn.Module):
+@common.parametrize("test_data", models)
+def test_qat_tosa_INT(test_data):
+    model, per_channel = test_data
     pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1)
     tosa_version = conftest.get_option("tosa_version")
     tosa_profiles = {
@@ -59,7 +63,7 @@ def test_qat_tosa_INT(model: torch.nn.Module):
         Quantize(
             quantizer=quantizer,
             quantization_config=get_symmetric_quantization_config(
-                is_qat=True, is_per_channel=False
+                is_qat=True, is_per_channel=per_channel
            ),
            is_qat=True,
        ),
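
The branch added to get_symmetric_quantization_config() reduces to the standalone sketch below. This is an illustration only, not the PR's code verbatim: the helper name select_weight_observer_or_fake_quant is hypothetical, while the imports are the standard torch.ao.quantization constructors the patch selects between.

    # Sketch of the weight observer/fake-quant selection this diff introduces.
    from torch.ao.quantization.fake_quantize import FakeQuantize
    from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


    def select_weight_observer_or_fake_quant(is_qat: bool, is_per_channel: bool):
        # Hypothetical helper mirroring the new branch in
        # get_symmetric_quantization_config().
        if is_qat:
            if is_per_channel:
                # QAT + per-channel: observe per-channel min/max ranges.
                return PerChannelMinMaxObserver
            # QAT + per-tensor: plain fake-quant with true min/max.
            return FakeQuantize
        # PTQ: plain min/max observers, per-channel or per-tensor.
        return PerChannelMinMaxObserver if is_per_channel else MinMaxObserver


    assert select_weight_observer_or_fake_quant(True, True) is PerChannelMinMaxObserver
    assert select_weight_observer_or_fake_quant(True, False) is FakeQuantize
    assert select_weight_observer_or_fake_quant(False, True) is PerChannelMinMaxObserver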
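
The test change covers both weight layouts by parametrizing over (model, is_per_channel) pairs. A downstream caller would request the same configs as in this short sketch; the import path follows this PR's tree, though treating these exact keyword arguments as stable API is an assumption.

    # Requesting the QAT weight configs the updated test now covers.
    from executorch.backends.arm.quantizer.arm_quantizer import (
        get_symmetric_quantization_config,
    )

    # Per-channel weight quantization during QAT (the path enabled here).
    qat_per_channel = get_symmetric_quantization_config(is_qat=True, is_per_channel=True)
    # Per-tensor weight quantization during QAT (the previously tested path).
    qat_per_tensor = get_symmetric_quantization_config(is_qat=True, is_per_channel=False)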