Skip to content

Commit 0c12dcd

Browse files
authored
Add tanh op to XNNPACK backend (#11804)
### Summary This PR adds support for the tanh operator in ExecuTorch via XNNPACK, enabling optimized execution of torch.tanh on the XNNPACK backend. The implementation includes updates to operator configuration, serialization, and runtime handling. The tanh operator is now properly registered in the XNNPACK partition config and mapped to XNNPACK's xnn_create_tanh_operator API in the compiler. ### Test plan I added a new test class TestTanh that is a simple torch model with a tanh op. It then asserts that the XNNPACK delegate was called while executing the tanh op instead of the torch default tanh op.
1 parent 4df9290 commit 0c12dcd

File tree

10 files changed

+145
-0
lines changed

10 files changed

+145
-0
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,6 @@
5050
op_static_constant_pad,
5151
op_static_resize_bilinear_2d,
5252
op_sub,
53+
op_tanh,
5354
op_to_copy,
5455
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNGraph,
16+
XNNTanh,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
@register_node_visitor
23+
class TanhVisitor(NodeVisitor):
24+
target = "aten.tanh.default"
25+
26+
def __init__(self, *args) -> None:
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
xnn_graph: XNNGraph,
33+
vals_to_ids: Dict[torch.fx.Node, int],
34+
debug_handle: int,
35+
) -> None:
36+
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
37+
38+
# input
39+
input_id = vals_to_ids[get_input_node(node, 0)]
40+
41+
# output
42+
output_id = vals_to_ids[node]
43+
44+
ser_node = XNode(
45+
xnode_union=XNNTanh(
46+
input_id=input_id,
47+
output_id=output_id,
48+
flags=0,
49+
),
50+
debug_handle=debug_handle,
51+
)
52+
xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SoftmaxConfig,
5050
SquareRootConfig,
5151
SubConfig,
52+
TanhConfig,
5253
UpsampleBilinear2dConfig,
5354
)
5455
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
99100
PreluConfig,
100101
ReciprocalSquareRootConfig,
101102
ReLUConfig,
103+
TanhConfig,
102104
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
103105
SigmoidConfig,
104106
SliceCopyConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
371371
return [ConfigPrecisionType.FP32]
372372

373373

374+
class TanhConfig(GenericNodePartitionerConfig):
375+
target_name = "tanh.default"
376+
377+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
378+
return [ConfigPrecisionType.FP32]
379+
380+
374381
class MeanDimConfig(GenericNodePartitionerConfig):
375382
target_name = "mean.dim"
376383

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
exir_ops.edge.aten.rsqrt.default,
6767
exir_ops.edge.aten.log.default,
6868
exir_ops.edge.aten.gelu.default,
69+
exir_ops.edge.aten.tanh.default,
6970
]
7071

7172
SUPPORTED_MODULES = [

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode(
11621162
return Error::Ok;
11631163
}
11641164

1165+
/*
1166+
Define serialized tanh node into the subgraph, using the remapped ids
1167+
to map the serialized ids, to the new ids generated when defining the
1168+
tensor value
1169+
*/
1170+
Error defineTanhNode(
1171+
xnn_subgraph_t subgraph_ptr,
1172+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1173+
const NodePtr node,
1174+
const fb_xnnpack::XNNGraph* graph) noexcept {
1175+
MAYBE_UNUSED(graph);
1176+
1177+
auto graph_node = node->xnode_union_as_XNNTanh();
1178+
1179+
xnn_status status = xnn_define_tanh(
1180+
subgraph_ptr,
1181+
remapped_ids.at(graph_node->input_id()),
1182+
remapped_ids.at(graph_node->output_id()),
1183+
graph_node->flags());
1184+
1185+
ET_CHECK_OR_RETURN_ERROR(
1186+
status == xnn_status_success,
1187+
Internal,
1188+
"Failed to create tanh node %i with code: %s",
1189+
node->debug_handle(),
1190+
xnn_status_to_string(status));
1191+
1192+
return Error::Ok;
1193+
}
1194+
11651195
/*
11661196
Defines serialized prelu node into the subgraph,
11671197
using the remapped ids to map the serialized ids,
@@ -1697,6 +1727,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
16971727
_DEFINE(Gelu)
16981728
_DEFINE(Hardswish)
16991729
_DEFINE(Log)
1730+
_DEFINE(Tanh)
17001731
_DEFINE(Negate)
17011732
_DEFINE(Square)
17021733
_DEFINE(Clamp)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ union XNodeUnion {
154154
XNNReciprocalSquareRoot: _XNNNode1x1,
155155
XNNLog: _XNNNode1x1,
156156
XNNGelu: _XNNNode1x1,
157+
XNNTanh: _XNNNode1x1,
157158
}
158159

159160
union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ union XNodeUnion {
150150
XNNReciprocalSquareRoot: _XNNNode1x1,
151151
XNNLog: _XNNNode1x1,
152152
XNNGelu: _XNNNode1x1,
153+
XNNTanh: _XNNNode1x1,
153154
}
154155

155156
union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ class XNNLog(XNNNode1x1):
319319
pass
320320

321321

322+
@dataclass
323+
class XNNTanh(XNNNode1x1):
324+
pass
325+
326+
322327
@dataclass
323328
class XNNMaximum(XNNNode2x1):
324329
pass
@@ -391,6 +396,7 @@ class XNNScaledDotProductAttention:
391396
XNNReciprocalSquareRoot,
392397
XNNLog,
393398
XNNGelu,
399+
XNNTanh,
394400
]
395401

396402

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Tester
11+
12+
13+
class TestTanh(unittest.TestCase):
14+
def setUp(self):
15+
torch._dynamo.reset()
16+
17+
class Tanh(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, x):
22+
return torch.tanh(x)
23+
24+
def run_tanh_test(self, inputs):
25+
(
26+
Tester(self.Tanh(), inputs)
27+
.export()
28+
.check_count({"torch.ops.aten.tanh.default": 1})
29+
.to_edge_transform_and_lower()
30+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
31+
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
32+
.to_executorch()
33+
.serialize()
34+
.run_method_and_compare_outputs()
35+
)
36+
37+
def test_fp16_tanh(self):
38+
inputs = (torch.randn(20).to(torch.float16),)
39+
self.run_tanh_test(inputs)
40+
41+
def test_fp32_tanh(self):
42+
inputs = (torch.randn(20),)
43+
self.run_tanh_test(inputs)

0 commit comments

Comments
 (0)