Skip to content

Commit 1356118

Browse files
authored
Arm backend: Fix for combo neg(x)+1 + tests (#13517)
Fix for quantization error for combo neg(x) +1. Add more tests on combos with unary ops. Signed-off-by: Elena Zhelezina <[email protected]>
1 parent ec0b57b commit 1356118

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,10 @@ def _match_pattern(
339339
torch.ops.aten.unflatten.int,
340340
torch.ops.aten.index_select.default,
341341
torch.ops.aten.index.Tensor,
342+
# Neg operator flips the range, but keps the magnitude the same.
343+
# That is why we force it to use the same qparams and avoid
344+
# dequant -> neg -> requant chain.
345+
torch.ops.aten.neg.default,
342346
]
343347

344348
_one_to_one_shared_input_or_input_act_qspec = [
@@ -540,9 +544,6 @@ def any_or_hardtanh_min_zero(n: Node):
540544
)
541545
]
542546
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
543-
elif node.target in (torch.ops.aten.neg.default,):
544-
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
545-
quant_properties.quant_output = _QuantProperty(0, input_act_qspec)
546547
elif node.target in _one_to_one:
547548
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
548549
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Tuple
6+
7+
import pytest
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
Tensor1 = Tuple[torch.Tensor]
20+
21+
22+
class NegAdd(torch.nn.Module):
23+
# neg(x) + 1
24+
edge_op_list = [
25+
"executorch_exir_dialects_edge__ops_aten_neg_default",
26+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
27+
]
28+
29+
def get_inputs(self) -> Tensor1:
30+
return (torch.rand(10, 10, 10),)
31+
32+
def forward(self, x):
33+
return torch.neg(x) + 1.0
34+
35+
36+
class MinAddZero(torch.nn.Module):
37+
# min(x, 0) + 1
38+
edge_op_list = [
39+
"executorch_exir_dialects_edge__ops_aten_full_like_default",
40+
"executorch_exir_dialects_edge__ops_aten_minimum_default",
41+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
42+
]
43+
44+
# range [-1, 1]
45+
def get_inputs(self) -> Tensor1:
46+
return (torch.rand(10, 10, 10) * 2 - 1,)
47+
48+
def forward(self, x):
49+
# We want Tensor-Tensor minimum
50+
z = torch.full_like(x, 0.0)
51+
return torch.minimum(x, z) + 1.0
52+
53+
54+
class MaxAddZero(torch.nn.Module):
55+
# max(x, 0) + 1.0
56+
edge_op_list = [
57+
"executorch_exir_dialects_edge__ops_aten_full_like_default",
58+
"executorch_exir_dialects_edge__ops_aten_maximum_default",
59+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
60+
]
61+
62+
# range [-1, 1]
63+
def get_inputs(self) -> Tensor1:
64+
return (torch.rand(10, 10, 10) * 2 - 1,)
65+
66+
def forward(self, x):
67+
z = torch.full_like(x, 0.0)
68+
return torch.maximum(x, z) + 1.0
69+
70+
71+
class AbsAdd(torch.nn.Module):
72+
# abs(x) + 1.0
73+
edge_op_list = [
74+
"executorch_exir_dialects_edge__ops_aten_abs_default",
75+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
76+
]
77+
78+
def get_inputs(self) -> Tensor1:
79+
return (torch.rand(10, 10, 10),)
80+
81+
def forward(self, x):
82+
return torch.abs(x) + 1.0
83+
84+
85+
MODELS = [NegAdd, AbsAdd, MaxAddZero, MinAddZero]
86+
87+
88+
def _build(model_cls):
89+
m = model_cls()
90+
return m, m.get_inputs(), model_cls.edge_op_list
91+
92+
93+
@pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__)
94+
def test_unary_combos_tosa_FP(model_cls):
95+
m, inputs, exir = _build(model_cls)
96+
p = TosaPipelineFP[Tensor1](m, inputs, aten_op=[], exir_op=exir)
97+
p.run()
98+
99+
100+
@pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__)
101+
def test_unary_combos_tosa_INT(model_cls):
102+
m, inputs, exir = _build(model_cls)
103+
p = TosaPipelineINT[Tensor1](m, inputs, aten_op=[], exir_op=exir, qtol=1)
104+
p.run()
105+
106+
107+
@common.XfailIfNoCorstone300
108+
@pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__)
109+
def test_unary_combos_u55_INT(model_cls):
110+
m, inputs, exir = _build(model_cls)
111+
p = EthosU55PipelineINT[Tensor1](
112+
m, inputs, aten_ops=[], exir_ops=exir, run_on_fvp=True
113+
)
114+
p.run()
115+
116+
117+
@common.XfailIfNoCorstone320
118+
@pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__)
119+
def test_unary_combos_u85_INT(model_cls):
120+
m, inputs, exir = _build(model_cls)
121+
p = EthosU85PipelineINT[Tensor1](
122+
m, inputs, aten_ops=[], exir_ops=exir, run_on_fvp=True
123+
)
124+
p.run()
125+
126+
127+
@common.SkipIfNoModelConverter
128+
@pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__)
129+
def test_unary_combos_vgf_INT(model_cls):
130+
m, inputs, exir = _build(model_cls)
131+
p = VgfPipeline[Tensor1](
132+
m, inputs, aten_op=[], exir_op=exir, tosa_version="TOSA-1.0+INT"
133+
)
134+
p.run()

0 commit comments

Comments
 (0)