Skip to content

Commit 4b1ae21

Browse files
Arm: Support ABS operator in Arm backend (#8459)
Support ABS operator in Arm backend
1 parent fea3684 commit 4b1ae21

File tree

5 files changed

+261
-0
lines changed

5 files changed

+261
-0
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class BaseTOSASupportList(OperatorSupportBase):
9191

9292
def is_node_supported(self, submodules, node: fx.Node) -> bool:
9393
supported = node.op == "call_function" and node.target in [
94+
exir_ops.edge.aten.abs.default,
9495
exir_ops.edge.aten.add.Tensor,
9596
exir_ops.edge.aten.expand_copy.default,
9697
exir_ops.edge.aten.cat.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import ( # noqa
99
node_visitor,
10+
op_abs,
1011
op_add,
1112
op_avg_pool2d,
1213
op_bmm,

backends/arm/operators/op_abs.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import executorch.backends.arm.tosa_quant_utils as tqutils
10+
import executorch.backends.arm.tosa_utils as tutils
11+
12+
import serializer.tosa_serializer as ts # type: ignore
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_specification import TosaSpecification
19+
20+
from serializer.tosa_serializer import TosaOp
21+
from torch.fx import Node
22+
23+
24+
@register_node_visitor
25+
class AbsVisitor_080_BI(NodeVisitor):
26+
target = "aten.abs.default"
27+
28+
tosa_specs = [
29+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
30+
]
31+
32+
def __init__(self, *args):
33+
super().__init__(*args)
34+
35+
def define_node(
36+
self,
37+
node: Node,
38+
tosa_graph: ts.TosaSerializer,
39+
inputs: List[TosaArg],
40+
output: TosaArg,
41+
) -> None:
42+
# Specification (0.80) states that input and output types
43+
# should all be the same
44+
if not (inputs[0].dtype == output.dtype):
45+
raise ValueError(
46+
"All inputs and outputs need same dtype."
47+
f"Got {inputs[0].dtype=}, {output.dtype=}"
48+
)
49+
# Handle int8 (quantized) and int32
50+
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
51+
raise ValueError(
52+
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
53+
)
54+
55+
if inputs[0].dtype == ts.DType.INT8:
56+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
57+
tosa_graph, inputs, node
58+
)
59+
else:
60+
# input[0].dtype == ts.DType.INT32
61+
# Non quantized input, natively support by TOSA.abs
62+
rescaled_inputs = inputs
63+
64+
if output.dtype == ts.DType.INT8:
65+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
66+
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
67+
else:
68+
# output.dtype == ts.DType.INT32
69+
abs_output = output
70+
71+
# Do the INT32 Abs
72+
tosa_graph.addOperator(
73+
TosaOp.Op().ABS,
74+
[
75+
rescaled_inputs[0].name,
76+
],
77+
[abs_output.name],
78+
None,
79+
)
80+
81+
if output.dtype == ts.DType.INT8:
82+
# Scale output back to 8 bit
83+
# pyre-ignore
84+
tqutils.insert_rescale_op_to_int8(tosa_graph, abs_output, scale_back, node) # type: ignore[possibly-undefined]
85+
86+
87+
@register_node_visitor
88+
class AbsVisitor_080_MI(AbsVisitor_080_BI):
89+
# inheriting 'target' from BI class
90+
91+
tosa_specs = [
92+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
93+
]
94+
95+
def __init__(self, *args):
96+
super().__init__(*args)
97+
98+
def define_node(
99+
self,
100+
node: Node,
101+
tosa_graph: ts.TosaSerializer,
102+
inputs: List[TosaArg],
103+
output: TosaArg,
104+
) -> None:
105+
# Specification (0.80) states that input and output types
106+
# should all be the same
107+
if not (inputs[0].dtype == output.dtype):
108+
raise ValueError(
109+
"All inputs and output need same dtype."
110+
f"Got {inputs[0].dtype=}, {output.dtype=}"
111+
)
112+
113+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
114+
# Call the inherited define_node for handling integers
115+
super().define_node(node, tosa_graph, inputs, output)
116+
else:
117+
# FP32 Abs lowering
118+
119+
if not (inputs[0].dtype == ts.DType.FP32):
120+
raise ValueError(
121+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
122+
)
123+
124+
if not (output.dtype == ts.DType.FP32):
125+
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
126+
127+
# MI lowering
128+
tosa_graph.addOperator(
129+
TosaOp.Op().ABS,
130+
[inputs[0].name],
131+
[output.name],
132+
None,
133+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def _match_pattern(
125125

126126

127127
_one_to_one = [
128+
torch.ops.aten.abs.default,
128129
torch.ops.aten.exp.default,
129130
torch.ops.aten.log.default,
130131
torch.ops.aten.reciprocal.default,

backends/arm/test/ops/test_abs.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2025 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
from typing import Tuple
11+
12+
import pytest
13+
14+
import torch
15+
from executorch.backends.arm.test import common, conftest
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
from executorch.exir.backend.compile_spec_schema import CompileSpec
18+
from parameterized import parameterized
19+
20+
21+
class TestAbs(unittest.TestCase):
22+
class Abs(torch.nn.Module):
23+
test_parameters = [
24+
(torch.zeros(5),),
25+
(torch.full((5,), -1, dtype=torch.float32),),
26+
(torch.ones(5) * -1,),
27+
(torch.randn(8),),
28+
(torch.randn(2, 3, 4),),
29+
(torch.randn(1, 2, 3, 4),),
30+
(torch.normal(mean=0, std=10, size=(2, 3, 4)),),
31+
]
32+
33+
def forward(self, x):
34+
return torch.abs(x)
35+
36+
def _test_abs_tosa_MI_pipeline(
37+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
38+
):
39+
(
40+
ArmTester(
41+
module,
42+
example_inputs=test_data,
43+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
44+
)
45+
.export()
46+
.check_count({"torch.ops.aten.abs.default": 1})
47+
.check_not(["torch.ops.quantized_decomposed"])
48+
.to_edge()
49+
.partition()
50+
.check_not(["torch.ops.aten.abs.default"])
51+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
52+
.to_executorch()
53+
.run_method_and_compare_outputs(inputs=test_data)
54+
)
55+
56+
def _test_abs_tosa_BI_pipeline(
57+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
58+
):
59+
(
60+
ArmTester(
61+
module,
62+
example_inputs=test_data,
63+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
64+
)
65+
.quantize()
66+
.export()
67+
.check_count({"torch.ops.aten.abs.default": 1})
68+
.check(["torch.ops.quantized_decomposed"])
69+
.to_edge()
70+
.partition()
71+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
72+
.to_executorch()
73+
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
74+
)
75+
76+
def _test_abs_ethosu_BI_pipeline(
77+
self,
78+
compile_spec: list[CompileSpec],
79+
module: torch.nn.Module,
80+
test_data: Tuple[torch.Tensor],
81+
):
82+
tester = (
83+
ArmTester(
84+
module,
85+
example_inputs=test_data,
86+
compile_spec=compile_spec,
87+
)
88+
.quantize()
89+
.export()
90+
.check_count({"torch.ops.aten.abs.default": 1})
91+
.check(["torch.ops.quantized_decomposed"])
92+
.to_edge()
93+
.partition()
94+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
95+
.to_executorch()
96+
.serialize()
97+
)
98+
if conftest.is_option_enabled("corstone_fvp"):
99+
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
100+
101+
@parameterized.expand(Abs.test_parameters)
102+
def test_abs_tosa_MI(self, test_data: torch.Tensor):
103+
test_data = (test_data,)
104+
self._test_abs_tosa_MI_pipeline(self.Abs(), test_data)
105+
106+
@parameterized.expand(Abs.test_parameters)
107+
def test_abs_tosa_BI(self, test_data: torch.Tensor):
108+
test_data = (test_data,)
109+
self._test_abs_tosa_BI_pipeline(self.Abs(), test_data)
110+
111+
@parameterized.expand(Abs.test_parameters)
112+
@pytest.mark.corstone_fvp
113+
def test_abs_u55_BI(self, test_data: torch.Tensor):
114+
test_data = (test_data,)
115+
self._test_abs_ethosu_BI_pipeline(
116+
common.get_u55_compile_spec(), self.Abs(), test_data
117+
)
118+
119+
@parameterized.expand(Abs.test_parameters)
120+
@pytest.mark.corstone_fvp
121+
def test_abs_u85_BI(self, test_data: torch.Tensor):
122+
test_data = (test_data,)
123+
self._test_abs_ethosu_BI_pipeline(
124+
common.get_u85_compile_spec(), self.Abs(), test_data
125+
)

0 commit comments

Comments
 (0)