
Commit c3567b4

Martin Lindström authored and committed
Arm backend: Make per-channel quantization default
Support for per-channel quantization was recently added to the Arm backend. This patch changes the default setting to per-channel quantization for the weights of convolutional and linear layers, instead of the previous per-tensor default. Per-channel quantization offers better numerical accuracy for models containing convolutional and/or fully connected layers, so unless the use case has an explicit limitation that prevents it, it is generally preferred. The quantization granularity can still be set manually using `get_symmetric_quantization_config(is_per_channel=False)`; this patch only changes the default.

Unit and model tests are affected by this change. Error tolerances for those tests have not been changed, since model outputs are compared against a reference that uses the exact same quantization strategy: if a model output is altered by this patch, the reference it is compared against is altered accordingly.

To verify the impact of this change on top-1 and top-5 accuracy, a manual test was run on MobileNetV2. The results show a noticeable improvement:

- Per-tensor quantization, Top-1 / Top-5 accuracy: 66.45% / 87.50%
- Per-channel quantization, Top-1 / Top-5 accuracy: 70.85% / 89.50%

Change-Id: I35d5c62741c7f93b916560874689245db96a588b
1 parent a2b620a commit c3567b4
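
For anyone who needs to keep the previous behaviour, the override is a one-liner on the quantizer config. A minimal sketch (the import path is assumed from this repo's layout, and `compile_spec` is assumed to come from the existing Ethos-U helpers):

# Sketch: keeping per-tensor weight quantization after this patch.
# Assumptions: the import path mirrors backends/arm/quantizer/arm_quantizer.py,
# and compile_spec is built elsewhere with the usual Ethos-U helpers.
from executorch.backends.arm.quantizer.arm_quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)

quantizer = EthosUQuantizer(compile_spec)
# is_per_channel now defaults to True; pass False explicitly to keep
# the old per-tensor granularity for conv/linear weights.
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)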

File tree

8 files changed (+46, -45 lines)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
 
 @functools.lru_cache
 def get_symmetric_quantization_config(
-    is_per_channel: bool = False,
+    is_per_channel: bool = True,
     is_qat: bool = False,
     is_dynamic: bool = False,
     act_qmin: int = -128,

backends/arm/test/misc/test_bn_relu_folding_qat.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def test_qat_tosa_BI(model: torch.nn.Module):
             "quantize",
             Quantize(
                 quantizer=quantizer,
-                quantization_config=get_symmetric_quantization_config(is_qat=True),
+                quantization_config=get_symmetric_quantization_config(
+                    is_qat=True, is_per_channel=False
+                ),
                 is_qat=True,
             ),
         )

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 0 additions & 3 deletions
@@ -46,7 +46,6 @@ def test_mv2_tosa_BI():
         aten_op=[],
         exir_op=[],
         use_to_edge_transform_and_lower=True,
-        per_channel_quantization=True,
         atol=0.25,
         qtol=1,
     )
@@ -63,7 +62,6 @@
         exir_ops=[],
         run_on_fvp=True,
         use_to_edge_transform_and_lower=True,
-        per_channel_quantization=True,
         atol=0.25,
         qtol=1,
     )
@@ -80,7 +78,6 @@
         exir_ops=[],
         run_on_fvp=True,
         use_to_edge_transform_and_lower=True,
-        per_channel_quantization=True,
         atol=0.25,
         qtol=1,
     )

backends/arm/test/ops/test_conv2d.py

Lines changed: 7 additions & 3 deletions
@@ -384,6 +384,10 @@ def forward(self, x):
     "2x2_3x2x40x40_nobias": "MLETORCH-520: Numerical issues on FVP.",
     "5x5_3x2x128x128_st1": "MLETORCH-520: Numerical issues on FVP.",
 }
+per_channel_quant_xfails = {
+    "groups": "MLETORCH-1144: Invalid TOSA Graph",
+    "groups_bias": "MLETORCH-1144: Invalid TOSA Graph",
+}
 input_t = Tuple[torch.Tensor]
@@ -398,7 +402,7 @@ def test_convolution_2d_tosa_MI(test_module):
     pipeline.run()
 
 
-@common.parametrize("test_module", test_modules)
+@common.parametrize("test_module", test_modules, per_channel_quant_xfails)
 def test_convolution_2d_tosa_BI(test_module):
     pipeline = TosaPipelineBI[input_t](
         test_module(),
@@ -410,7 +414,7 @@ def test_convolution_2d_tosa_BI(test_module):
     pipeline.run()
 
 
-@common.parametrize("test_module", test_modules, fvp_xfails)
+@common.parametrize("test_module", test_modules, fvp_xfails | per_channel_quant_xfails)
 @common.XfailIfNoCorstone300
 def test_convolution_2d_u55_BI(test_module):
     pipeline = EthosU55PipelineBI[input_t](
@@ -423,7 +427,7 @@ def test_convolution_2d_u55_BI(test_module):
     pipeline.run()
 
 
-@common.parametrize("test_module", test_modules, fvp_xfails)
+@common.parametrize("test_module", test_modules, fvp_xfails | per_channel_quant_xfails)
 @common.XfailIfNoCorstone320
 def test_convolution_2d_u85_BI(test_module):
     pipeline = EthosU85PipelineBI[input_t](

backends/arm/test/ops/test_multihead_attention.py

Lines changed: 8 additions & 1 deletion
@@ -53,7 +53,14 @@ def test_multihead_attention_tosa_MI(test_data: input_t1):
 )
 def test_multihead_attention_tosa_BI(test_data):
     test_data, module = test_data()
-    pipeline = TosaPipelineBI(module, (*test_data, *test_data, *test_data), [], [])
+    pipeline = TosaPipelineBI(
+        module,
+        (*test_data, *test_data, *test_data),
+        [],
+        [],
+        # TODO: Per-channel quantization is broken (MLETORCH-1144)
+        per_channel_quantization=False,
+    )
     pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 25 additions & 33 deletions
@@ -299,7 +299,7 @@ def __init__(
         run_on_tosa_ref_model: bool = True,
         tosa_version: str = "TOSA-0.80+BI",
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
@@ -316,16 +316,14 @@ def __init__(
         compile_spec = common.get_tosa_compile_spec(
             tosa_profiles[tosa_version], custom_path=custom_path
         )
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+
+        quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)
 
         super().__init__(
             module,
@@ -474,24 +472,21 @@ def __init__(
         exir_ops: Optional[str | List[str]] = None,
         run_on_fvp: bool = True,
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
         qtol: int = 1,
     ):
         compile_spec = common.get_u55_compile_spec(custom_path=custom_path)
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = EthosUQuantizer(compile_spec)
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+        quantizer = EthosUQuantizer(compile_spec)
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)
 
         super().__init__(
             module,
@@ -564,24 +559,21 @@ def __init__(
         exir_ops: str | List[str] = None,
         run_on_fvp: bool = True,
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
         qtol: int = 1,
     ):
         compile_spec = common.get_u85_compile_spec(custom_path=custom_path)
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = EthosUQuantizer(compile_spec)
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+        quantizer = EthosUQuantizer(compile_spec)
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)
 
         super().__init__(
             module,

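In the test suites, individual cases can still opt out through the pipeline flag, as the multihead-attention test above does. A hypothetical minimal usage (`module` and `example_inputs` are placeholders, not names from this diff):

pipeline = TosaPipelineBI(
    module,          # placeholder: model under test
    example_inputs,  # placeholder: representative inputs
    [],              # aten ops to check
    [],              # exir ops to check
    per_channel_quantization=False,  # override the new default for this test
)
pipeline.run()
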
examples/arm/aot_arm_compiler.py

Lines changed: 1 addition & 2 deletions
@@ -160,8 +160,7 @@ def quantize(
     else:
         raise RuntimeError("Unsupported compilespecs for quantization!")
 
-    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
-    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    operator_config = get_symmetric_quantization_config()
     quantizer.set_global(operator_config)
     m = prepare_pt2e(model, quantizer)

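For context, the `quantize()` helper above feeds this config into the standard PT2E post-training-quantization flow. Roughly, as a sketch: `prepare_pt2e`/`convert_pt2e` are the stock torch helpers, and `example_inputs` is a hypothetical calibration batch.

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Per-channel is now the default, so no argument is needed here.
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
m = prepare_pt2e(model, quantizer)  # insert observers into the exported graph
m(*example_inputs)                  # calibrate on representative data (hypothetical batch)
m = convert_pt2e(m)                 # replace observers with quantize/dequantize ops
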
examples/arm/ethos_u_minimal_example.ipynb

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@
 "\n",
 "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n",
 "quantizer = EthosUQuantizer(compile_spec)\n",
-"operator_config = get_symmetric_quantization_config(is_per_channel=False)\n",
+"operator_config = get_symmetric_quantization_config()\n",
 "quantizer.set_global(operator_config)\n",
 "\n",
 "# Post training quantization\n",

0 commit comments
