Skip to content

Commit cb0f53e

Browse files
authored
Arm backend: Add tanh operator
Differential Revision: D64427390 Pull Request resolved: #6226
1 parent 8c96805 commit cb0f53e

File tree

5 files changed

+223
-0
lines changed

5 files changed

+223
-0
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6565
exir_ops.edge.aten.slice_copy.Tensor,
6666
exir_ops.edge.aten.sub.Tensor,
6767
exir_ops.edge.aten.sum.dim_IntList,
68+
exir_ops.edge.aten.tanh.default,
6869
exir_ops.edge.aten.view_copy.default,
6970
exir_ops.edge.aten.clone.default,
7071
exir_ops.edge.aten.mean.dim,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_squeeze,
3535
op_sub,
3636
op_sum,
37+
op_tanh,
3738
op_transpose,
3839
op_unsqueeze,
3940
op_view,

backends/arm/operators/op_tanh.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import numpy as np
10+
11+
import serializer.tosa_serializer as ts
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
18+
from executorch.backends.arm.tosa_quant_utils import (
19+
dequantize_value,
20+
get_quant_node_args,
21+
QuantArgs,
22+
quantize_value,
23+
)
24+
from serializer.tosa_serializer import TosaOp
25+
from torch.fx import Node
26+
27+
28+
@register_node_visitor
29+
class TanhVisitor(NodeVisitor):
30+
target = "aten.tanh.default"
31+
32+
def __init__(self, *args):
33+
super().__init__(*args)
34+
35+
def define_node(
36+
self,
37+
node: Node,
38+
tosa_graph: ts.TosaSerializer,
39+
inputs: List[TosaArg],
40+
output: TosaArg,
41+
is_quant_node: bool,
42+
) -> None:
43+
44+
assert len(node.all_input_nodes) == 1
45+
46+
if is_quant_node:
47+
# Assume quantized input is 8 bit.
48+
assert len(node.users) == 1
49+
50+
# Create attribute for 8 bit table lookup.
51+
input_node = node.all_input_nodes[0]
52+
in_quantargs = get_quant_node_args(input_node)
53+
output_node = list(node.users)[0]
54+
out_quantargs = get_quant_node_args(output_node)
55+
56+
table = tanh_table_8bit(in_quantargs, out_quantargs)
57+
table_attr = ts.TosaSerializerAttribute()
58+
table_attr.TableAttribute(table)
59+
60+
tosa_graph.addOperator(
61+
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
62+
)
63+
else:
64+
tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])
65+
66+
67+
def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
68+
"""
69+
Returns a table mapping 256 entries to tanh([qmin,qmax])
70+
Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh
71+
"""
72+
73+
def tanh(x):
74+
# Convert quantized input to floating point tanh input space.
75+
v = dequantize_value(x, in_quantargs)
76+
# Compute tanh.
77+
v = np.exp(-2.0 * v)
78+
v = (1.0 - v) / (1.0 + v)
79+
80+
# Convert tanh output back to quantized space.
81+
return quantize_value(v, out_quantargs)
82+
83+
return [
84+
tanh(x)
85+
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
86+
]

backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def _annotate_one_to_one(
4141
torch.ops.aten.reciprocal.default,
4242
torch.ops.aten.rsqrt.default,
4343
torch.ops.aten.sigmoid.default,
44+
torch.ops.aten.tanh.default,
4445
)
4546
for node in gm.graph.nodes:
4647
if node.op != "call_function" or node.target not in one_to_one_ops:

backends/arm/test/ops/test_tanh.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
from typing import Tuple
11+
12+
import torch
13+
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from parameterized import parameterized
18+
19+
20+
test_data_suite = [
21+
# (test_name, test_data)
22+
("zeros", torch.zeros(10, 10, 10, 10)),
23+
("ones", torch.ones(10, 10, 10)),
24+
("rand", torch.rand(10, 10) - 0.5),
25+
("randn_pos", torch.randn(10) + 10),
26+
("randn_neg", torch.randn(10) - 10),
27+
("ramp", torch.arange(-16, 16, 0.2)),
28+
]
29+
30+
31+
class TestTanh(unittest.TestCase):
32+
class Tanh(torch.nn.Module):
33+
def __init__(self):
34+
super().__init__()
35+
self.tanh = torch.nn.Tanh()
36+
37+
def forward(self, x):
38+
return self.tanh(x)
39+
40+
def _test_tanh_tosa_MI_pipeline(
41+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
42+
):
43+
(
44+
ArmTester(
45+
module,
46+
example_inputs=test_data,
47+
compile_spec=common.get_tosa_compile_spec(),
48+
)
49+
.export()
50+
.check(["torch.ops.aten.tanh.default"])
51+
.check_not(["torch.ops.quantized_decomposed"])
52+
.to_edge()
53+
.partition()
54+
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
55+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
56+
.to_executorch()
57+
.run_method_and_compare_outputs(inputs=test_data)
58+
)
59+
60+
def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
61+
(
62+
ArmTester(
63+
module,
64+
example_inputs=test_data,
65+
compile_spec=common.get_tosa_compile_spec(),
66+
)
67+
.quantize()
68+
.export()
69+
.check(["torch.ops.aten.tanh.default"])
70+
.check(["torch.ops.quantized_decomposed"])
71+
.to_edge()
72+
.partition()
73+
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
74+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
75+
.to_executorch()
76+
.run_method_and_compare_outputs(inputs=test_data)
77+
)
78+
79+
def _test_tanh_tosa_ethos_BI_pipeline(
80+
self,
81+
compile_spec: list[CompileSpec],
82+
module: torch.nn.Module,
83+
test_data: Tuple[torch.tensor],
84+
):
85+
(
86+
ArmTester(
87+
module,
88+
example_inputs=test_data,
89+
compile_spec=compile_spec,
90+
)
91+
.quantize()
92+
.export()
93+
.check_count({"torch.ops.aten.tanh.default": 1})
94+
.check(["torch.ops.quantized_decomposed"])
95+
.to_edge()
96+
.partition()
97+
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
98+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
99+
.to_executorch()
100+
)
101+
102+
def _test_tanh_tosa_u55_BI_pipeline(
103+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
104+
):
105+
self._test_tanh_tosa_ethos_BI_pipeline(
106+
common.get_u55_compile_spec(), module, test_data
107+
)
108+
109+
def _test_tanh_tosa_u85_BI_pipeline(
110+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
111+
):
112+
self._test_tanh_tosa_ethos_BI_pipeline(
113+
common.get_u85_compile_spec(), module, test_data
114+
)
115+
116+
@parameterized.expand(test_data_suite)
117+
def test_tanh_tosa_MI(
118+
self,
119+
test_name: str,
120+
test_data: torch.Tensor,
121+
):
122+
self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,))
123+
124+
@parameterized.expand(test_data_suite)
125+
def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
126+
self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,))
127+
128+
@parameterized.expand(test_data_suite)
129+
def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
130+
self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,))
131+
132+
@parameterized.expand(test_data_suite)
133+
def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
134+
self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,))

0 commit comments

Comments
 (0)