diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index d17b7abd6a1..93424b1c84d 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -41,6 +41,7 @@ op_relu, op_rsqrt, op_sigmoid, + op_sin, op_skip_ops, op_slice_copy, op_softmax, diff --git a/backends/xnnpack/operators/op_sin.py b/backends/xnnpack/operators/op_sin.py new file mode 100644 index 00000000000..56fe9396103 --- /dev/null +++ b/backends/xnnpack/operators/op_sin.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNSin, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class SinVisitor(NodeVisitor): + target = "aten.sin.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNSin( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index e393f1c9ac8..86baba3e3f7 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -45,6 +45,7 @@ ReciprocalSquareRootConfig, ReLUConfig, SigmoidConfig, + SinConfig, SliceCopyConfig, SoftmaxConfig, SquareRootConfig, @@ -105,6 +106,7 @@ TanhConfig, ToDimOrderCopyConfig, SigmoidConfig, + SinConfig, SliceCopyConfig, SoftmaxConfig, SquareRootConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 559d1522275..06024c632c9 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -636,3 +636,10 @@ class BMMConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] + + +class SinConfig(GenericNodePartitionerConfig): + target_name = "sin.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index eb9b668dafa..b71ab08ea45 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1690,6 +1690,7 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log) _DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate) _DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square) _DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs) +_DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine) // Unary Ops with min/max params _DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp) @@ -1737,6 +1738,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Floor) _DEFINE(PReLU) _DEFINE(Sigmoid) + _DEFINE(Sin) // Others _DEFINE(FullyConnected) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 950318f18dc..239f92d899e 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -156,6 +156,7 @@ union XNodeUnion { XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, + XNNSin: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index a4efc627cbb..92a61c5537b 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -152,6 +152,7 @@ union XNodeUnion { XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, + XNNSin: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 99b64708f86..2b3f8e74202 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -347,6 +347,11 @@ class XNNPReLU(XNNNode2x1): pass +@dataclass +class XNNSin(XNNNode1x1): + pass + + @dataclass class XNNScaledDotProductAttention: query_id: int @@ -402,6 +407,8 @@ class XNNScaledDotProductAttention: XNNLog, XNNGelu, XNNTanh, + XNNExp, + XNNSin, ] diff --git a/backends/xnnpack/test/ops/test_sin.py b/backends/xnnpack/test/ops/test_sin.py new file mode 100644 index 00000000000..6a1b323e14c --- /dev/null +++ b/backends/xnnpack/test/ops/test_sin.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestSin(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Sin(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.sin(x) + return z + + def _test_sin(self, inputs, legacy_mode: bool = False): + tester = ( + Tester(self.Sin(), inputs) + .export() + .check_count({"torch.ops.aten.sin.default": 1}) + ) + + if legacy_mode: + tester = tester.to_edge().partition() + else: + tester = tester.to_edge_transform_and_lower() + + ( + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_sin_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_sin(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_sin(inputs, legacy_mode=False) + + def test_fp16_sin_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_sin(inputs, legacy_mode=True) + + def test_fp32_sin(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_sin(inputs, legacy_mode=False) + + def test_fp32_sin_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_sin(inputs, legacy_mode=True)