diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index ec07502de54..35ce639978d 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -50,5 +50,6 @@
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
     op_sub,
+    op_tanh,
     op_to_copy,
 )
diff --git a/backends/xnnpack/operators/op_tanh.py b/backends/xnnpack/operators/op_tanh.py
new file mode 100644
index 00000000000..6031839eceb
--- /dev/null
+++ b/backends/xnnpack/operators/op_tanh.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNTanh,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class TanhVisitor(NodeVisitor):
+    target = "aten.tanh.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNTanh(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 553b10f60d1..207d2cfd713 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -49,6 +49,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    TanhConfig,
    UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    TanhConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 46922e47010..e7e298053c6 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class TanhConfig(GenericNodePartitionerConfig):
+    target_name = "tanh.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
 
diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py
index 59a70d64a76..246f571c9c8 100644
--- a/backends/xnnpack/partition/configs.py
+++ b/backends/xnnpack/partition/configs.py
@@ -66,6 +66,7 @@
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.gelu.default,
+    exir_ops.edge.aten.tanh.default,
 ]
 
 SUPPORTED_MODULES = [
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 7241280ab35..b724ab1a9d9 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode(
   return Error::Ok;
 }
 
+/*
+Define serialized tanh node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineTanhNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNTanh();
+
+  xnn_status status = xnn_define_tanh(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create tanh node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Defines serialized prelu node into the subgraph,
 using the remapped ids to map the serialized ids,
@@ -1697,6 +1727,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Gelu)
     _DEFINE(Hardswish)
     _DEFINE(Log)
+    _DEFINE(Tanh)
     _DEFINE(Negate)
     _DEFINE(Square)
     _DEFINE(Clamp)
diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs
index a0d44327912..eea4cdb8b86 100644
--- a/backends/xnnpack/serialization/runtime_schema.fbs
+++ b/backends/xnnpack/serialization/runtime_schema.fbs
@@ -154,6 +154,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }
 
 union XValueUnion {
diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs
index eeab28154cc..ed444005c64 100644
--- a/backends/xnnpack/serialization/schema.fbs
+++ b/backends/xnnpack/serialization/schema.fbs
@@ -150,6 +150,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }
 
 union XValueUnion {
diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py
index dc50fb47da4..106eb6b81d9 100644
--- a/backends/xnnpack/serialization/xnnpack_graph_schema.py
+++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -319,6 +319,11 @@ class XNNLog(XNNNode1x1):
     pass
 
 
+@dataclass
+class XNNTanh(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass
@@ -391,6 +396,7 @@ class XNNScaledDotProductAttention:
     XNNReciprocalSquareRoot,
     XNNLog,
     XNNGelu,
+    XNNTanh,
 ]
 
 
diff --git a/backends/xnnpack/test/ops/test_tanh.py b/backends/xnnpack/test/ops/test_tanh.py
new file mode 100644
index 00000000000..e7bac4541c9
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_tanh.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestTanh(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Tanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    def run_tanh_test(self, inputs):
+        (
+            Tester(self.Tanh(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.tanh.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_tanh(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_tanh_test(inputs)
+
+    def test_fp32_tanh(self):
+        inputs = (torch.randn(20),)
+        self.run_tanh_test(inputs)
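
For context, here is a minimal sketch of exercising the new lowering end to end, outside the test harness. The module name and output path are illustrative; this assumes the standard `torch.export` / `to_edge_transform_and_lower` flow with `XnnpackPartitioner`, which is how the partitioner configs added in this diff get picked up:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


# Hypothetical model for illustration; any graph containing aten.tanh.default works.
class TanhModel(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# Export, partition, and lower. With this change, aten.tanh.default should be
# claimed by the XNNPACK delegate instead of falling back to portable kernels.
exported = torch.export.export(TanhModel(), (torch.randn(20),))
executorch_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("tanh_xnnpack.pte", "wb") as f:
    f.write(executorch_program.buffer)
```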