1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
@@ -41,6 +41,7 @@
op_relu,
op_rsqrt,
op_sigmoid,
op_sin,
op_skip_ops,
op_slice_copy,
op_softmax,
52 changes: 52 additions & 0 deletions backends/xnnpack/operators/op_sin.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGraph,
XNNSin,
XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class SinVisitor(NodeVisitor):
target = "aten.sin.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

# input
input_id = vals_to_ids[get_input_node(node, 0)]

# output
output_id = vals_to_ids[node]

ser_node = XNode(
xnode_union=XNNSin(
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
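
Note: importing op_sin in operators/__init__.py (above) is what runs the @register_node_visitor decorator at import time. For orientation, the registry pattern this implies looks roughly like the simplified sketch below; the real decorator and dispatch live in executorch.backends.xnnpack.operators.node_visitor, and the details here are illustrative only.

# Hypothetical, simplified sketch of visitor registration and dispatch.
_node_visitors = {}

def register_node_visitor(visitor_cls):
    # Key the visitor by its target string, e.g. "aten.sin.default".
    _node_visitors[visitor_cls.target] = visitor_cls
    return visitor_cls

def visit(node, xnn_graph, vals_to_ids, debug_handle):
    # Called while walking the FX graph during serialization.
    visitor_cls = _node_visitors.get(str(node.target))
    if visitor_cls is not None:
        visitor_cls().define_node(node, xnn_graph, vals_to_ids, debug_handle)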
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -45,6 +45,7 @@
ReciprocalSquareRootConfig,
ReLUConfig,
SigmoidConfig,
SinConfig,
SliceCopyConfig,
SoftmaxConfig,
SquareRootConfig,
@@ -105,6 +106,7 @@
TanhConfig,
ToDimOrderCopyConfig,
SigmoidConfig,
SinConfig,
SliceCopyConfig,
SoftmaxConfig,
SquareRootConfig,
7 changes: 7 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -636,3 +636,10 @@ class BMMConfig(GenericNodePartitionerConfig):

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class SinConfig(GenericNodePartitionerConfig):
target_name = "sin.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]
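
For context, SinConfig inherits the generic matching machinery: the partitioner claims a node when its target matches target_name and its precision type is supported. A simplified, illustrative version of that target check (the real helpers in GenericNodePartitionerConfig differ):

# Illustrative only; not the actual partitioner code.
import torch

def node_matches(node: torch.fx.Node, target_name: str) -> bool:
    # target_name is "sin.default"; str(node.target) prints as
    # "aten.sin.default" for the edge op, so compare the suffix.
    return node.op == "call_function" and str(node.target).endswith(target_name)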
2 changes: 2 additions & 0 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1690,6 +1690,7 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log)
_DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate)
_DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square)
_DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs)
_DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine)

// Unary Ops with min/max params
_DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp)
@@ -1737,6 +1738,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Floor)
_DEFINE(PReLU)
_DEFINE(Sigmoid)
_DEFINE(Sin)

// Others
_DEFINE(FullyConnected)
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
@@ -156,6 +156,7 @@ union XNodeUnion {
XNNGelu: _XNNNode1x1,
XNNTanh: _XNNNode1x1,
XNNExp: _XNNNode1x1,
XNNSin: _XNNNode1x1,
}

union XValueUnion {
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
@@ -152,6 +152,7 @@ union XNodeUnion {
XNNGelu: _XNNNode1x1,
XNNTanh: _XNNNode1x1,
XNNExp: _XNNNode1x1,
XNNSin: _XNNNode1x1,
}

union XValueUnion {
7 changes: 7 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -347,6 +347,11 @@ class XNNPReLU(XNNNode2x1):
pass


@dataclass
class XNNSin(XNNNode1x1):
pass


@dataclass
class XNNScaledDotProductAttention:
query_id: int
@@ -402,6 +407,8 @@ class XNNScaledDotProductAttention:
XNNLog,
XNNGelu,
XNNTanh,
XNNExp,
XNNSin,
]


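
With XNNSin registered in the node-union list above, the serializer can emit it like any other 1-input/1-output node. A minimal construction sketch mirroring what SinVisitor.define_node produces (the ids below are placeholders, not real tensor-value ids):

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNSin,
    XNode,
)

# XNNSin inherits input_id/output_id/flags from XNNNode1x1.
ser_node = XNode(
    xnode_union=XNNSin(input_id=0, output_id=1, flags=0),
    debug_handle=0,
)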
87 changes: 87 additions & 0 deletions backends/xnnpack/test/ops/test_sin.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestSin(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

class Sin(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
z = torch.sin(x)
return z

def _test_sin(self, inputs, legacy_mode: bool = False):
tester = (
Tester(self.Sin(), inputs)
.export()
.check_count({"torch.ops.aten.sin.default": 1})
)

if legacy_mode:
tester = tester.to_edge().partition()
else:
tester = tester.to_edge_transform_and_lower()

(
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_sin_default"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp16_sin(self):
inputs = (
torch.Tensor(
[
[0.0, 0.1, 0.5, 0.785398],
[-0.5, -0.785398, 1.5708, -1.5708],
],
).to(torch.float16),
)
self._test_sin(inputs, legacy_mode=False)

def test_fp16_sin_legacy_mode(self):
inputs = (
torch.Tensor(
[
[0.0, 0.1, 0.5, 0.785398],
[-0.5, -0.785398, 1.5708, -1.5708],
],
).to(torch.float16),
)
self._test_sin(inputs, legacy_mode=True)

def test_fp32_sin(self):
inputs = (
torch.Tensor(
[
[0.0, 0.1, 0.5, 0.785398],
[-0.5, -0.785398, 1.5708, -1.5708],
],
),
)
self._test_sin(inputs, legacy_mode=False)

def test_fp32_sin_legacy_mode(self):
inputs = (
torch.Tensor(
[
[0.0, 0.1, 0.5, 0.785398],
[-0.5, -0.785398, 1.5708, -1.5708],
],
),
)
self._test_sin(inputs, legacy_mode=True)
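
Outside the Tester harness, the lowering flow these tests exercise looks roughly like the sketch below, using the public export APIs (import paths can shift between ExecuTorch releases, so treat this as an assumption-laden example rather than canonical usage):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class Sin(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


exported = torch.export.export(Sin(), (torch.randn(2, 4),))

# The aten.sin.default nodes should be claimed by the XNNPACK delegate here.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()

with open("sin_xnnpack.pte", "wb") as f:
    f.write(et_program.buffer)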