Skip to content

Commit c87751e

Browse files
committed
Arm backend: Annotate ADD/SUB with indepenedent observers
We were annotating the ADD/SUB with a shared observer resulting in the same quantisation parameters on the two inputs even if we were adding numbers in different ranges(positive tensor to a tensor with positive and negative values). As a result, the quantisation parameters were suboptimal. This change annotates the operator with independent observers and changes how we rescale the two inputs to bring them to the same range. Added a unit test of a resnet model. Lowered the number of channels on a few unit tests in order to keep the Total SRAM Used below 2MB for the Ethos-U55 to fit within the memory limit of the Corstone-300. Change-Id: I7adde636f901c9df6b779d946a157e66fd12e24e
1 parent c003b8e commit c87751e

File tree

13 files changed

+358
-38
lines changed

13 files changed

+358
-38
lines changed

backends/arm/operators/op_add.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ def define_node(
5353
[ts.DType.INT8, ts.DType.INT32],
5454
output.tosa_spec,
5555
)
56-
5756
scale_back = 1.0
5857
if inputs[0].dtype == ts.DType.INT8:
59-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
58+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6059
tosa_graph, inputs, node, self.tosa_spec
6160
)
6261
else:
@@ -85,7 +84,12 @@ def define_node(
8584
# Scale output back to 8 bit
8685
# pyre-ignore
8786
tqutils.insert_rescale_op_to_int8(
88-
tosa_graph, add_output, scale_back, node, self.tosa_spec
87+
tosa_graph,
88+
add_output,
89+
scale_back,
90+
node,
91+
compute_rescale=False,
92+
tosa_spec=self.tosa_spec,
8993
) # type: ignore[possibly-undefined]
9094

9195

backends/arm/operators/op_sub.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def define_node(
5656

5757
scale_back = 1.0
5858
if inputs[0].dtype == ts.DType.INT8:
59-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
59+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6060
tosa_graph, inputs, node, self.tosa_spec
6161
)
6262
else:
@@ -86,7 +86,12 @@ def define_node(
8686
# Scale output back to 8 bit
8787
# pyre-ignore
8888
tqutils.insert_rescale_op_to_int8(
89-
tosa_graph, sub_output, scale_back, node, self.tosa_spec
89+
tosa_graph,
90+
sub_output,
91+
scale_back,
92+
node,
93+
compute_rescale=False,
94+
tosa_spec=self.tosa_spec,
9095
) # type: ignore[possibly-undefined]
9196

9297

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,10 @@ def any_or_hardtanh_min_zero(n: Node):
472472
]
473473
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
474474
elif node.target in (
475+
torch.ops.aten.add.Tensor,
476+
torch.ops.aten.add_.Tensor,
477+
torch.ops.aten.sub.Tensor,
478+
torch.ops.aten.sub_.Tensor,
475479
torch.ops.aten.matmul.default,
476480
torch.ops.aten.mm.default,
477481
torch.ops.aten.bmm.default,
@@ -484,10 +488,6 @@ def any_or_hardtanh_min_zero(n: Node):
484488
]
485489
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
486490
elif node.target in (
487-
torch.ops.aten.add.Tensor,
488-
torch.ops.aten.add_.Tensor,
489-
torch.ops.aten.sub.Tensor,
490-
torch.ops.aten.sub_.Tensor,
491491
torch.ops.aten.minimum.default,
492492
torch.ops.aten.maximum.default,
493493
):
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
10+
import torch
11+
import torch.nn as nn
12+
from executorch.backends.arm.test import common
13+
14+
from executorch.backends.arm.test.tester.test_pipeline import (
15+
EthosU55PipelineINT,
16+
EthosU85PipelineINT,
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
)
20+
21+
22+
# Model with Conv1D - ReLU sequence and a residual add.
23+
# Testing the annotation of Conv1D-ReLU(to be fused) and annotation of add.
24+
# ReLU outputs positive numbers and linear outputs positive and negative numbers, so they
25+
# should have different quantisation parameters. If the ReLU gets wrong quantisation parameters(e.g. qmin!=zp)
26+
# because of a shared observer of a following operators(e.g. add), the Conv1D-ReLU sequence is not fused
27+
# and is left in FP32. As a result, the test fails.
28+
class AddDifferentRanges(torch.nn.Module):
29+
def __init__(self, in_channels, out_channels, kernel_size, input_dim):
30+
super().__init__()
31+
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
32+
self.relu = torch.nn.ReLU()
33+
self.linear = nn.Linear(out_channels, out_channels)
34+
35+
def forward(self, x):
36+
# Permute: (N, T, C) -> (N, C, T)
37+
x = x.permute(0, 2, 1)
38+
x = self.conv1(x)
39+
x = self.relu(x)
40+
x = x.permute(0, 2, 1)
41+
out = x + self.linear(x)
42+
return out
43+
44+
45+
input_t = Tuple[torch.Tensor]
46+
model = AddDifferentRanges(in_channels=3, out_channels=16, kernel_size=3, input_dim=10)
47+
model_inputs = (torch.randn(1, 10, 3),)
48+
quant_test_data = {
49+
"per_channel_quantization=true": True,
50+
"per_channel_quantization=false": False,
51+
}
52+
53+
54+
def test_tosa_FP():
55+
pipeline = TosaPipelineFP[input_t](
56+
model,
57+
model_inputs,
58+
aten_op=[],
59+
exir_op=[],
60+
use_to_edge_transform_and_lower=True,
61+
)
62+
pipeline.run()
63+
64+
65+
@common.parametrize("per_channel_quantization", quant_test_data)
66+
def test_tosa_INT(per_channel_quantization):
67+
pipeline = TosaPipelineINT[input_t](
68+
model,
69+
model_inputs,
70+
aten_op=[],
71+
exir_op=[],
72+
use_to_edge_transform_and_lower=True,
73+
per_channel_quantization=per_channel_quantization,
74+
qtol=0,
75+
)
76+
pipeline.run()
77+
78+
79+
@pytest.mark.slow
80+
@common.XfailIfNoCorstone300
81+
@common.parametrize("per_channel_quantization", quant_test_data)
82+
def test_tosa_u55_INT(per_channel_quantization):
83+
pipeline = EthosU55PipelineINT[input_t](
84+
model,
85+
model_inputs,
86+
[],
87+
[],
88+
run_on_fvp=True,
89+
use_to_edge_transform_and_lower=True,
90+
per_channel_quantization=per_channel_quantization,
91+
qtol=0,
92+
)
93+
pipeline.run()
94+
95+
96+
@pytest.mark.slow
97+
@common.XfailIfNoCorstone320
98+
@common.parametrize("per_channel_quantization", quant_test_data)
99+
def test_tosa_u85_INT(per_channel_quantization):
100+
pipeline = EthosU85PipelineINT[input_t](
101+
model,
102+
model_inputs,
103+
[],
104+
[],
105+
run_on_fvp=True,
106+
use_to_edge_transform_and_lower=True,
107+
per_channel_quantization=per_channel_quantization,
108+
qtol=0,
109+
)
110+
pipeline.run()

backends/arm/test/models/test_inception_v3_arm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_ic3_tosa_BI():
5151
aten_op=[],
5252
exir_op=[],
5353
use_to_edge_transform_and_lower=True,
54-
atol=0.6,
54+
atol=0.65,
5555
qtol=1,
5656
)
5757
pipeline.run()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
10+
import torch
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
14+
EthosU85PipelineINT,
15+
TosaPipelineFP,
16+
TosaPipelineINT,
17+
)
18+
19+
from torchvision import transforms # type: ignore[import-untyped]
20+
from torchvision.models import resnet18, ResNet18_Weights
21+
22+
model = resnet18(weights=ResNet18_Weights)
23+
model = model.eval()
24+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25+
26+
model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
27+
28+
input_t = Tuple[torch.Tensor]
29+
30+
31+
quant_test_data = {
32+
"per_channel_quantization=true": True,
33+
"per_channel_quantization=false": False,
34+
}
35+
36+
37+
def test_resnet_tosa_FP():
38+
pipeline = TosaPipelineFP[input_t](
39+
model,
40+
model_inputs,
41+
aten_op=[],
42+
exir_op=[],
43+
use_to_edge_transform_and_lower=True,
44+
)
45+
pipeline.run()
46+
47+
48+
@common.parametrize("per_channel_quantization", quant_test_data)
49+
def test_resnet_tosa_INT(per_channel_quantization):
50+
pipeline = TosaPipelineINT[input_t](
51+
model,
52+
model_inputs,
53+
aten_op=[],
54+
exir_op=[],
55+
use_to_edge_transform_and_lower=True,
56+
per_channel_quantization=per_channel_quantization,
57+
atol=0.5,
58+
qtol=1,
59+
)
60+
pipeline.run()
61+
62+
63+
@pytest.mark.slow
64+
@common.XfailIfNoCorstone300
65+
@common.parametrize("per_channel_quantization", quant_test_data)
66+
def test_resnet_u55_INT(per_channel_quantization):
67+
pipeline = EthosU55PipelineINT[input_t](
68+
model,
69+
model_inputs,
70+
aten_ops=[],
71+
exir_ops=[],
72+
run_on_fvp=True,
73+
use_to_edge_transform_and_lower=True,
74+
per_channel_quantization=per_channel_quantization,
75+
atol=0.5,
76+
qtol=1,
77+
)
78+
pipeline.run()
79+
80+
81+
@pytest.mark.slow
82+
@pytest.mark.xfail(
83+
reason="For resnet18 for Ethos-U85, the SRAM memory footprint is very high. The compiler team is investigating."
84+
)
85+
@common.XfailIfNoCorstone320
86+
@common.parametrize("per_channel_quantization", quant_test_data)
87+
def test_resnet_u85_INT(per_channel_quantization):
88+
pipeline = EthosU85PipelineINT[input_t](
89+
model,
90+
model_inputs,
91+
aten_ops=[],
92+
exir_ops=[],
93+
run_on_fvp=True,
94+
use_to_edge_transform_and_lower=True,
95+
per_channel_quantization=per_channel_quantization,
96+
atol=0.5,
97+
qtol=1,
98+
)
99+
pipeline.run()

backends/arm/test/ops/test_add.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
5858
"4d_randn_1": lambda: (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
5959
"4d_randn_2": lambda: (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
6060
"4d_randn_big": lambda: (
61-
10000 * torch.randn(1, 1, 4, 4),
61+
(1 << 30) * torch.randn(1, 1, 4, 4),
6262
torch.randn(1, 1, 4, 1),
6363
),
6464
"4d_randn_1_mutltiple_broadcasts": lambda: (
6565
torch.randn(1, 4, 4, 1),
6666
torch.ones(1, 1, 4, 4),
6767
),
68+
"4d_big_small": lambda: (
69+
(10e10) * torch.randn(1, 10, 20, 30),
70+
torch.randn(1, 10, 20, 30),
71+
),
6872
}
6973

7074

@@ -87,7 +91,7 @@ def test_add_tensor_tosa_FP(test_data: input_t1):
8791

8892
@common.parametrize("test_data", Add.test_data)
8993
def test_add_tensor_tosa_INT(test_data: input_t1):
90-
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op)
94+
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0)
9195
pipeline.run()
9296

9397

@@ -112,9 +116,16 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1):
112116
quant_max=2**31 - 1,
113117
quant_min=-(2**31),
114118
)
119+
output_act_qspec = QuantizationSpec(
120+
torch.int32,
121+
observer,
122+
qscheme=torch.per_tensor_symmetric,
123+
quant_max=2**31 - 1,
124+
quant_min=-(2**31),
125+
)
115126
# This quantization_config will be set as global config.
116127
quantization_config = arm_quantizer.QuantizationConfig(
117-
input_act_qspec, None, None, None
128+
input_act_qspec, output_act_qspec, None, None
118129
)
119130
quantize_stage = Quantize(quantizer, quantization_config)
120131
pipeline.change_args("quantize", quantize_stage)
@@ -158,13 +169,13 @@ def test_add_tensor_tosa_FP_3(test_data: input_t2):
158169

159170
@common.parametrize("test_data", Add3.test_data)
160171
def test_add_tensor_tosa_INT_3(test_data: input_t2):
161-
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op)
172+
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op, qtol=0)
162173
pipeline.run()
163174

164175

165176
@common.parametrize("test_data", Add2.test_data)
166177
def test_add_tensor_tosa_INT_2(test_data: input_t2):
167-
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op)
178+
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op, qtol=0)
168179
pipeline.run()
169180

170181

backends/arm/test/ops/test_conv_combos.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,28 @@ def __init__(self):
4747
# (t, c, n, s) = (6, 96, 1, 1)
4848
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
4949
self.pointwise_conv2d = torch.nn.Conv2d(
50-
in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1
50+
in_channels=16, out_channels=96, kernel_size=1, stride=1, groups=1
5151
) ## (1, 128, 81, 81)
52-
self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False)
52+
self.batch_norm2d_16 = torch.nn.BatchNorm2d(96, affine=False)
5353
self.relu6 = torch.nn.ReLU6()
5454

5555
# 2. 3x3 DepthwiseConv2d + ReLu6
5656
self.depthwise_conv2d = torch.nn.Conv2d(
57-
in_channels=128,
58-
out_channels=128,
57+
in_channels=96,
58+
out_channels=96,
5959
kernel_size=3,
6060
padding=1,
6161
stride=1,
62-
groups=128,
62+
groups=96,
6363
) ## (1, 128, H, W)
6464

6565
# 3. Linear 1x1 Conv2d
6666
self.pointwise_conv2d_linear = torch.nn.Conv2d(
67-
in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1
67+
in_channels=96, out_channels=16, kernel_size=1, stride=1, groups=1
6868
) ## (1, 32, 81, 81)
6969

7070
def get_inputs(self) -> Tuple[torch.Tensor]:
71-
return (torch.randn(1, 32, 81, 81),)
71+
return (torch.randn(1, 16, 81, 81),)
7272

7373
def forward(self, x):
7474
input = x

backends/arm/test/ops/test_group_norm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def test_native_group_norm_tosa_INT(test_data):
102102
"test_data",
103103
test_data_suite,
104104
xfails={
105+
"rand_4_6_8_groups_2_eps_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
106+
"rand_4_6_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
107+
"rand_4_6_groups_2": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
105108
"randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
106109
"rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
107110
"rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",

0 commit comments

Comments
 (0)