diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index 93424b1c84d..dd7d4f3e9fe 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -15,6 +15,7 @@ op_ceiling, op_clamp, op_conv2d, + op_cos, op_div, op_dynamic_dequantize_ops, op_dynamic_quantize_ops, diff --git a/backends/xnnpack/operators/op_cos.py b/backends/xnnpack/operators/op_cos.py new file mode 100644 index 00000000000..aa3166c96dd --- /dev/null +++ b/backends/xnnpack/operators/op_cos.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNCos, + XNNGraph, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class CosVisitor(NodeVisitor): + target = "aten.cos.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNCos( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 86baba3e3f7..6fbb95436e8 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -23,6 +23,7 @@ CeilConfig, ClampConfig, ConstantPadConfig, + CosConfig, DeQuantizedPerTensorConfig, DivConfig, # EluConfig, @@ -107,6 +108,7 @@ ToDimOrderCopyConfig, SigmoidConfig, SinConfig, + CosConfig, SliceCopyConfig, SoftmaxConfig, SquareRootConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 06024c632c9..52162036904 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -643,3 +643,9 @@ class SinConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] + +class CosConfig(GenericNodePartitionerConfig): + target_name = "cos.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 3e697566ce5..a1b075cb340 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1694,6 +1694,7 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate) _DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square) _DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs) _DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine) +_DEFINE_UNARY_NODE_NO_PARAMS(Cos, xnn_unary_cosine) // Unary Ops with min/max params _DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp) @@ -1742,6 +1743,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(PReLU) _DEFINE(Sigmoid) _DEFINE(Sin) + _DEFINE(Cos) // Others _DEFINE(FullyConnected) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 239f92d899e..c624bb35f4f 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -157,6 +157,7 @@ union XNodeUnion { XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, XNNSin: _XNNNode1x1, + XNNCos: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 92a61c5537b..0967530ef66 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -153,6 +153,7 @@ union XNodeUnion { XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, XNNSin: _XNNNode1x1, + XNNCos: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 2b3f8e74202..00165c451e2 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -351,6 +351,10 @@ class XNNPReLU(XNNNode2x1): class XNNSin(XNNNode1x1): pass +@dataclass +class XNNCos(XNNNode1x1): + pass + @dataclass class XNNScaledDotProductAttention: @@ -409,6 +413,7 @@ class XNNScaledDotProductAttention: XNNTanh, XNNExp, XNNSin, + XNNCos, ] diff --git a/backends/xnnpack/test/ops/test_cos.py b/backends/xnnpack/test/ops/test_cos.py new file mode 100644 index 00000000000..5bec5c1a89e --- /dev/null +++ b/backends/xnnpack/test/ops/test_cos.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestCos(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Cos(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.cos(x) + return z + + def _test_cos(self, inputs, legacy_mode: bool = False): + tester = ( + Tester(self.Cos(), inputs) + .export() + .check_count({"torch.ops.aten.cos.default": 1}) + ) + + if legacy_mode: + tester = tester.to_edge().partition() + else: + tester = tester.to_edge_transform_and_lower() + + ( + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_cos_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_cos(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_cos(inputs, legacy_mode=False) + + def test_fp16_cos_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_cos(inputs, legacy_mode=True) + + def test_fp32_cos(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_cos(inputs, legacy_mode=False) + + def test_fp32_cos_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_cos(inputs, legacy_mode=True)