Skip to content

Commit 8ba92a9

Browse files
gggekovdigantdesai
andauthored
Arm backend: Annotate ADD/SUB with indepenedent observers (#13516)
We were annotating the ADD/SUB with a shared observer resulting in the same quantisation parameters on the two inputs even if we were adding numbers in different ranges(positive tensor to a tensor with positive and negative values). As a result, the quantisation parameters were suboptimal. This change annotates the operator with independent observers and changes how we rescale the two inputs to bring them to the same range. Added a unit test of a resnet model. Lowered the number of channels on a few unit tests in order to keep the Total SRAM Used below 2MB for the Ethos-U55 to fit within the memory limit of the Corstone-300. Fixes #12959 Co-authored-by: Digant Desai <[email protected]>
1 parent 899d7e5 commit 8ba92a9

13 files changed

+358
-38
lines changed

backends/arm/operators/op_add.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ def define_node(
5353
[ts.DType.INT8, ts.DType.INT32],
5454
output.tosa_spec,
5555
)
56-
5756
scale_back = 1.0
5857
if inputs[0].dtype == ts.DType.INT8:
59-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
58+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6059
tosa_graph, inputs, node, self.tosa_spec
6160
)
6261
else:
@@ -85,7 +84,12 @@ def define_node(
8584
# Scale output back to 8 bit
8685
# pyre-ignore
8786
tqutils.insert_rescale_op_to_int8(
88-
tosa_graph, add_output, scale_back, node, self.tosa_spec
87+
tosa_graph,
88+
add_output,
89+
scale_back,
90+
node,
91+
compute_rescale=False,
92+
tosa_spec=self.tosa_spec,
8993
) # type: ignore[possibly-undefined]
9094

9195

backends/arm/operators/op_sub.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def define_node(
5656

5757
scale_back = 1.0
5858
if inputs[0].dtype == ts.DType.INT8:
59-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
59+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6060
tosa_graph, inputs, node, self.tosa_spec
6161
)
6262
else:
@@ -86,7 +86,12 @@ def define_node(
8686
# Scale output back to 8 bit
8787
# pyre-ignore
8888
tqutils.insert_rescale_op_to_int8(
89-
tosa_graph, sub_output, scale_back, node, self.tosa_spec
89+
tosa_graph,
90+
sub_output,
91+
scale_back,
92+
node,
93+
compute_rescale=False,
94+
tosa_spec=self.tosa_spec,
9095
) # type: ignore[possibly-undefined]
9196

9297

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,10 @@ def any_or_hardtanh_min_zero(n: Node):
473473
]
474474
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
475475
elif node.target in (
476+
torch.ops.aten.add.Tensor,
477+
torch.ops.aten.add_.Tensor,
478+
torch.ops.aten.sub.Tensor,
479+
torch.ops.aten.sub_.Tensor,
476480
torch.ops.aten.matmul.default,
477481
torch.ops.aten.mm.default,
478482
torch.ops.aten.bmm.default,
@@ -485,10 +489,6 @@ def any_or_hardtanh_min_zero(n: Node):
485489
]
486490
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
487491
elif node.target in (
488-
torch.ops.aten.add.Tensor,
489-
torch.ops.aten.add_.Tensor,
490-
torch.ops.aten.sub.Tensor,
491-
torch.ops.aten.sub_.Tensor,
492492
torch.ops.aten.minimum.default,
493493
torch.ops.aten.maximum.default,
494494
):
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
10+
import torch
11+
import torch.nn as nn
12+
from executorch.backends.arm.test import common
13+
14+
from executorch.backends.arm.test.tester.test_pipeline import (
15+
EthosU55PipelineINT,
16+
EthosU85PipelineINT,
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
)
20+
21+
22+
# Model with Conv1D - ReLU sequence and a residual add.
23+
# Testing the annotation of Conv1D-ReLU(to be fused) and annotation of add.
24+
# ReLU outputs positive numbers and linear outputs positive and negative numbers, so they
25+
# should have different quantisation parameters. If the ReLU gets wrong quantisation parameters(e.g. qmin!=zp)
26+
# because of a shared observer of a following operators(e.g. add), the Conv1D-ReLU sequence is not fused
27+
# and is left in FP32. As a result, the test fails.
28+
class AddDifferentRanges(torch.nn.Module):
29+
def __init__(self, in_channels, out_channels, kernel_size, input_dim):
30+
super().__init__()
31+
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
32+
self.relu = torch.nn.ReLU()
33+
self.linear = nn.Linear(out_channels, out_channels)
34+
35+
def forward(self, x):
36+
# Permute: (N, T, C) -> (N, C, T)
37+
x = x.permute(0, 2, 1)
38+
x = self.conv1(x)
39+
x = self.relu(x)
40+
x = x.permute(0, 2, 1)
41+
out = x + self.linear(x)
42+
return out
43+
44+
45+
input_t = Tuple[torch.Tensor]
46+
model = AddDifferentRanges(in_channels=3, out_channels=16, kernel_size=3, input_dim=10)
47+
model_inputs = (torch.randn(1, 10, 3),)
48+
quant_test_data = {
49+
"per_channel_quantization=true": True,
50+
"per_channel_quantization=false": False,
51+
}
52+
53+
54+
def test_tosa_FP():
55+
pipeline = TosaPipelineFP[input_t](
56+
model,
57+
model_inputs,
58+
aten_op=[],
59+
exir_op=[],
60+
use_to_edge_transform_and_lower=True,
61+
)
62+
pipeline.run()
63+
64+
65+
@common.parametrize("per_channel_quantization", quant_test_data)
66+
def test_tosa_INT(per_channel_quantization):
67+
pipeline = TosaPipelineINT[input_t](
68+
model,
69+
model_inputs,
70+
aten_op=[],
71+
exir_op=[],
72+
use_to_edge_transform_and_lower=True,
73+
per_channel_quantization=per_channel_quantization,
74+
qtol=0,
75+
)
76+
pipeline.run()
77+
78+
79+
@pytest.mark.slow
80+
@common.XfailIfNoCorstone300
81+
@common.parametrize("per_channel_quantization", quant_test_data)
82+
def test_tosa_u55_INT(per_channel_quantization):
83+
pipeline = EthosU55PipelineINT[input_t](
84+
model,
85+
model_inputs,
86+
[],
87+
[],
88+
run_on_fvp=True,
89+
use_to_edge_transform_and_lower=True,
90+
per_channel_quantization=per_channel_quantization,
91+
qtol=0,
92+
)
93+
pipeline.run()
94+
95+
96+
@pytest.mark.slow
97+
@common.XfailIfNoCorstone320
98+
@common.parametrize("per_channel_quantization", quant_test_data)
99+
def test_tosa_u85_INT(per_channel_quantization):
100+
pipeline = EthosU85PipelineINT[input_t](
101+
model,
102+
model_inputs,
103+
[],
104+
[],
105+
run_on_fvp=True,
106+
use_to_edge_transform_and_lower=True,
107+
per_channel_quantization=per_channel_quantization,
108+
qtol=0,
109+
)
110+
pipeline.run()

backends/arm/test/models/test_inception_v3_arm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_ic3_tosa_BI():
5151
aten_op=[],
5252
exir_op=[],
5353
use_to_edge_transform_and_lower=True,
54-
atol=0.6,
54+
atol=0.65,
5555
qtol=1,
5656
)
5757
pipeline.run()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
10+
import torch
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
14+
EthosU85PipelineINT,
15+
TosaPipelineFP,
16+
TosaPipelineINT,
17+
)
18+
19+
from torchvision import transforms # type: ignore[import-untyped]
20+
from torchvision.models import resnet18, ResNet18_Weights
21+
22+
model = resnet18(weights=ResNet18_Weights)
23+
model = model.eval()
24+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25+
26+
model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
27+
28+
input_t = Tuple[torch.Tensor]
29+
30+
31+
quant_test_data = {
32+
"per_channel_quantization=true": True,
33+
"per_channel_quantization=false": False,
34+
}
35+
36+
37+
def test_resnet_tosa_FP():
38+
pipeline = TosaPipelineFP[input_t](
39+
model,
40+
model_inputs,
41+
aten_op=[],
42+
exir_op=[],
43+
use_to_edge_transform_and_lower=True,
44+
)
45+
pipeline.run()
46+
47+
48+
@common.parametrize("per_channel_quantization", quant_test_data)
49+
def test_resnet_tosa_INT(per_channel_quantization):
50+
pipeline = TosaPipelineINT[input_t](
51+
model,
52+
model_inputs,
53+
aten_op=[],
54+
exir_op=[],
55+
use_to_edge_transform_and_lower=True,
56+
per_channel_quantization=per_channel_quantization,
57+
atol=0.5,
58+
qtol=1,
59+
)
60+
pipeline.run()
61+
62+
63+
@pytest.mark.slow
64+
@common.XfailIfNoCorstone300
65+
@common.parametrize("per_channel_quantization", quant_test_data)
66+
def test_resnet_u55_INT(per_channel_quantization):
67+
pipeline = EthosU55PipelineINT[input_t](
68+
model,
69+
model_inputs,
70+
aten_ops=[],
71+
exir_ops=[],
72+
run_on_fvp=True,
73+
use_to_edge_transform_and_lower=True,
74+
per_channel_quantization=per_channel_quantization,
75+
atol=0.5,
76+
qtol=1,
77+
)
78+
pipeline.run()
79+
80+
81+
@pytest.mark.slow
82+
@pytest.mark.xfail(
83+
reason="For resnet18 for Ethos-U85, the SRAM memory footprint is very high. The compiler team is investigating."
84+
)
85+
@common.XfailIfNoCorstone320
86+
@common.parametrize("per_channel_quantization", quant_test_data)
87+
def test_resnet_u85_INT(per_channel_quantization):
88+
pipeline = EthosU85PipelineINT[input_t](
89+
model,
90+
model_inputs,
91+
aten_ops=[],
92+
exir_ops=[],
93+
run_on_fvp=True,
94+
use_to_edge_transform_and_lower=True,
95+
per_channel_quantization=per_channel_quantization,
96+
atol=0.5,
97+
qtol=1,
98+
)
99+
pipeline.run()

backends/arm/test/ops/test_add.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
5757
"4d_randn_1": lambda: (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
5858
"4d_randn_2": lambda: (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
5959
"4d_randn_big": lambda: (
60-
10000 * torch.randn(1, 1, 4, 4),
60+
(1 << 30) * torch.randn(1, 1, 4, 4),
6161
torch.randn(1, 1, 4, 1),
6262
),
6363
"4d_randn_1_mutltiple_broadcasts": lambda: (
6464
torch.randn(1, 4, 4, 1),
6565
torch.ones(1, 1, 4, 4),
6666
),
67+
"4d_big_small": lambda: (
68+
(10e10) * torch.randn(1, 10, 20, 30),
69+
torch.randn(1, 10, 20, 30),
70+
),
6771
}
6872

6973

@@ -86,7 +90,7 @@ def test_add_tensor_tosa_FP(test_data: input_t1):
8690

8791
@common.parametrize("test_data", Add.test_data)
8892
def test_add_tensor_tosa_INT(test_data: input_t1):
89-
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op)
93+
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0)
9094
pipeline.run()
9195

9296

@@ -111,9 +115,16 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1):
111115
quant_max=2**31 - 1,
112116
quant_min=-(2**31),
113117
)
118+
output_act_qspec = QuantizationSpec(
119+
torch.int32,
120+
observer,
121+
qscheme=torch.per_tensor_symmetric,
122+
quant_max=2**31 - 1,
123+
quant_min=-(2**31),
124+
)
114125
# This quantization_config will be set as global config.
115126
quantization_config = arm_quantizer.QuantizationConfig(
116-
input_act_qspec, None, None, None
127+
input_act_qspec, output_act_qspec, None, None
117128
)
118129
quantize_stage = Quantize(quantizer, quantization_config)
119130
pipeline.change_args("quantize", quantize_stage)
@@ -157,13 +168,13 @@ def test_add_tensor_tosa_FP_3(test_data: input_t2):
157168

158169
@common.parametrize("test_data", Add3.test_data)
159170
def test_add_tensor_tosa_INT_3(test_data: input_t2):
160-
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op)
171+
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op, qtol=0)
161172
pipeline.run()
162173

163174

164175
@common.parametrize("test_data", Add2.test_data)
165176
def test_add_tensor_tosa_INT_2(test_data: input_t2):
166-
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op)
177+
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op, qtol=0)
167178
pipeline.run()
168179

169180

backends/arm/test/ops/test_conv_combos.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,28 @@ def __init__(self):
4747
# (t, c, n, s) = (6, 96, 1, 1)
4848
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
4949
self.pointwise_conv2d = torch.nn.Conv2d(
50-
in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1
50+
in_channels=16, out_channels=96, kernel_size=1, stride=1, groups=1
5151
) ## (1, 128, 81, 81)
52-
self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False)
52+
self.batch_norm2d_16 = torch.nn.BatchNorm2d(96, affine=False)
5353
self.relu6 = torch.nn.ReLU6()
5454

5555
# 2. 3x3 DepthwiseConv2d + ReLu6
5656
self.depthwise_conv2d = torch.nn.Conv2d(
57-
in_channels=128,
58-
out_channels=128,
57+
in_channels=96,
58+
out_channels=96,
5959
kernel_size=3,
6060
padding=1,
6161
stride=1,
62-
groups=128,
62+
groups=96,
6363
) ## (1, 128, H, W)
6464

6565
# 3. Linear 1x1 Conv2d
6666
self.pointwise_conv2d_linear = torch.nn.Conv2d(
67-
in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1
67+
in_channels=96, out_channels=16, kernel_size=1, stride=1, groups=1
6868
) ## (1, 32, 81, 81)
6969

7070
def get_inputs(self) -> Tuple[torch.Tensor]:
71-
return (torch.randn(1, 32, 81, 81),)
71+
return (torch.randn(1, 16, 81, 81),)
7272

7373
def forward(self, x):
7474
input = x

backends/arm/test/ops/test_group_norm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def test_native_group_norm_tosa_INT(test_data):
102102
"test_data",
103103
test_data_suite,
104104
xfails={
105+
"rand_4_6_8_groups_2_eps_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
106+
"rand_4_6_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
107+
"rand_4_6_groups_2": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
105108
"randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
106109
"rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
107110
"rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",

0 commit comments

Comments
 (0)