Commit ecfa831

GregoryComer authored and facebook-github-bot committed
Support cosine operator on XNNPACK
Summary: Wire up the unary cosine operator in XNNPACK for fp32 and fp16.

Differential Revision: D83623619
1 parent 30d7cae commit ecfa831
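
For orientation, a minimal sketch of how the new operator is exercised end to end: export a module that calls torch.cos and lower it through the XNNPACK partitioner. The import paths and APIs below follow current ExecuTorch conventions and should be treated as assumptions; they are not something this commit defines.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class CosModule(torch.nn.Module):
    def forward(self, x):
        return torch.cos(x)


# fp32 example input; fp16 works the same way via .to(torch.float16).
example_inputs = (torch.randn(2, 4),)
exported = torch.export.export(CosModule(), example_inputs)

# Lower to the XNNPACK delegate; aten.cos.default should be absorbed into
# an executorch_call_delegate node, matching what the new test asserts.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()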

File tree

9 files changed, +157 -0 lines changed


backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     op_ceiling,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_div,
     op_dynamic_dequantize_ops,
     op_dynamic_quantize_ops,
backends/xnnpack/operators/op_cos.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNCos,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class CosVisitor(NodeVisitor):
+    target = "aten.cos.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNCos(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
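
For context, the target string "aten.cos.default" that CosVisitor registers against is the name of the ATen op as it appears in an exported FX graph. A small illustrative check (assuming a recent PyTorch with torch.export; the module name here is made up for the example):

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.cos(x)


ep = torch.export.export(M(), (torch.randn(2, 4),))
# Call-function targets in the exported graph; expected to include aten.cos.default.
print([str(n.target) for n in ep.graph.nodes if n.op == "call_function"])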

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     CeilConfig,
     ClampConfig,
     ConstantPadConfig,
+    CosConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
     # EluConfig,
@@ -107,6 +108,7 @@
     ToDimOrderCopyConfig,
     SigmoidConfig,
     SinConfig,
+    CosConfig,
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 6 additions & 0 deletions
@@ -643,3 +643,9 @@ class SinConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+class CosConfig(GenericNodePartitionerConfig):
+    target_name = "cos.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 2 additions & 0 deletions
@@ -1694,6 +1694,7 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate)
 _DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square)
 _DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs)
 _DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine)
+_DEFINE_UNARY_NODE_NO_PARAMS(Cos, xnn_unary_cosine)

 // Unary Ops with min/max params
 _DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp)
@@ -1742,6 +1743,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(PReLU)
     _DEFINE(Sigmoid)
     _DEFINE(Sin)
+    _DEFINE(Cos)

     // Others
     _DEFINE(FullyConnected)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCos: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCos: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 5 additions & 0 deletions
@@ -351,6 +351,10 @@ class XNNPReLU(XNNNode2x1):
 class XNNSin(XNNNode1x1):
     pass

+@dataclass
+class XNNCos(XNNNode1x1):
+    pass
+

 @dataclass
 class XNNScaledDotProductAttention:
@@ -409,6 +413,7 @@ class XNNScaledDotProductAttention:
     XNNTanh,
     XNNExp,
     XNNSin,
+    XNNCos,
 ]

backends/xnnpack/test/ops/test_cos.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestCos(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Cos(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            z = torch.cos(x)
+            return z
+
+    def _test_cos(self, inputs, legacy_mode: bool = False):
+        tester = (
+            Tester(self.Cos(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.cos.default": 1})
+        )
+
+        if legacy_mode:
+            tester = tester.to_edge().partition()
+        else:
+            tester = tester.to_edge_transform_and_lower()
+
+        (
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_cos_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_cos(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ).to(torch.float16),
+        )
+        self._test_cos(inputs, legacy_mode=False)
+
+    def test_fp16_cos_legacy_mode(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ).to(torch.float16),
+        )
+        self._test_cos(inputs, legacy_mode=True)
+
+    def test_fp32_cos(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ),
+        )
+        self._test_cos(inputs, legacy_mode=False)
+
+    def test_fp32_cos_legacy_mode(self):
+        inputs = (
+            torch.Tensor(
+                [
+                    [0.0, 0.1, 0.5, 0.785398],
+                    [-0.5, -0.785398, 1.5708, -1.5708],
+                ],
+            ),
+        )
+        self._test_cos(inputs, legacy_mode=True)
