Skip to content

Commit c1633ed

Browse files
Arm backend: Support for sin and cos for TOSA 1.0
SIN and COS were introduced in TOSA 1.0. This patch adds support for both. It also adds unittests for both ops. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ic5a71a677b06045c067c0990b2b7f04ca5e98e2b
1 parent 0a06a5e commit c1633ed

File tree

10 files changed

+298
-1
lines changed

10 files changed

+298
-1
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class TableOps:
4848
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
4949
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
5050
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
51+
exir_ops.edge.aten.cos.default: torch.cos,
52+
exir_ops.edge.aten.sin.default: torch.sin,
5153
exir_ops.edge.aten.tanh.default: torch.tanh,
5254
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5355
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
pool_2d_support,
1313
reduce_sum_support,
1414
right_shift_support,
15+
sin_cos_support,
1516
slice_copy_support,
1617
to_copy_support,
1718
tosa_supported_operators,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
@register_tosa_support_check
19+
class SinCosSupported(SupportedTOSAOperatorCheck):
20+
targets = [
21+
exir_ops.edge.aten.cos.default,
22+
exir_ops.edge.aten.sin.default,
23+
]
24+
25+
tosa_specs = [
26+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
27+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
28+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
29+
]
30+
31+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
32+
return True

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
op_clamp,
1919
op_constant_pad_nd,
2020
op_conv2d,
21+
op_cos,
2122
op_eq,
2223
op_erf,
2324
op_exp,
@@ -38,6 +39,7 @@
3839
op_rshift_tensor,
3940
op_rsqrt,
4041
op_sigmoid,
42+
op_sin,
4143
op_slice,
4244
op_sub,
4345
op_sum,

backends/arm/operators/op_cos.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
16+
from torch.fx import Node
17+
18+
19+
@register_node_visitor
20+
class CosVisitor(NodeVisitor):
21+
target = "aten.cos.default"
22+
23+
# INT case should be handled by op_table
24+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
25+
26+
def __init__(self, *args):
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: Node,
32+
tosa_graph: ts.TosaSerializer,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
if len(node.all_input_nodes) != 1:
37+
raise ValueError(
38+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
39+
)
40+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
41+
raise ValueError(
42+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
43+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
44+
)
45+
46+
tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name])

backends/arm/operators/op_sin.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
16+
from torch.fx import Node
17+
18+
19+
@register_node_visitor
20+
class SinVisitor(NodeVisitor):
21+
target = "aten.sin.default"
22+
23+
# INT case should be handled by op_table
24+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
25+
26+
def __init__(self, *args):
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: Node,
32+
tosa_graph: ts.TosaSerializer,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
if len(node.all_input_nodes) != 1:
37+
raise ValueError(
38+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
39+
)
40+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
41+
raise ValueError(
42+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
43+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
44+
)
45+
46+
tosa_graph.addOperator(ts.TosaOp.Op().SIN, [inputs[0].name], [output.name])

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def _match_pattern(
171171
torch.ops.aten.reciprocal.default,
172172
torch.ops.aten.rsqrt.default,
173173
torch.ops.aten.sigmoid.default,
174+
torch.ops.aten.cos.default,
175+
torch.ops.aten.sin.default,
174176
torch.ops.aten.tanh.default,
175177
torch.ops.aten.sum.dim_IntList,
176178
torch.ops.aten.hardsigmoid.default,

backends/arm/test/misc/test_multiple_delegates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_inputs(self):
2020

2121
def forward(self, x: torch.Tensor, y: torch.Tensor):
2222
z = x + y
23-
s = torch.sin(z)
23+
s = torch.tan(z)
2424
return s * z
2525

2626
def test_tosa_MI(self):

backends/arm/test/ops/test_cos.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
aten_op = "torch.ops.aten.cos.default"
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
22+
test_data_suite = {
23+
# (test_name, test_data)
24+
"zeros": torch.zeros(10, 10, 10, 10),
25+
"ones": torch.ones(10, 10, 10),
26+
"rand": torch.rand(10, 10) - 0.5,
27+
"randn_pos": torch.randn(10) + 10,
28+
"randn_neg": torch.randn(10) - 10,
29+
"ramp": torch.arange(-16, 16, 0.2),
30+
}
31+
32+
33+
class Cos(torch.nn.Module):
34+
35+
def forward(self, x: torch.Tensor):
36+
return torch.cos(x)
37+
38+
39+
@common.parametrize("test_data", test_data_suite)
40+
def test_cos_tosa_MI(test_data: Tuple):
41+
pipeline = TosaPipelineMI[input_t1](
42+
Cos(),
43+
(test_data,),
44+
aten_op,
45+
exir_op=[],
46+
)
47+
if conftest.get_option("tosa_version") == "1.0":
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", test_data_suite)
52+
def test_cos_tosa_BI(test_data: Tuple):
53+
pipeline = TosaPipelineBI[input_t1](
54+
Cos(),
55+
(test_data,),
56+
aten_op,
57+
exir_op=[],
58+
)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", test_data_suite)
63+
def test_cos_tosa_u55_BI(test_data: Tuple):
64+
pipeline = EthosU55PipelineBI[input_t1](
65+
Cos(),
66+
(test_data,),
67+
aten_op,
68+
exir_ops=[],
69+
run_on_fvp=False,
70+
)
71+
pipeline.run()
72+
73+
74+
@common.parametrize("test_data", test_data_suite)
75+
def test_cos_tosa_u85_BI(test_data: Tuple):
76+
pipeline = EthosU85PipelineBI[input_t1](
77+
Cos(),
78+
(test_data,),
79+
aten_op,
80+
exir_ops=[],
81+
run_on_fvp=False,
82+
)
83+
pipeline.run()

backends/arm/test/ops/test_sin.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
aten_op = "torch.ops.aten.sin.default"
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
22+
test_data_suite = {
23+
# (test_name, test_data)
24+
"zeros": torch.zeros(10, 10, 10, 10),
25+
"ones": torch.ones(10, 10, 10),
26+
"rand": torch.rand(10, 10) - 0.5,
27+
"randn_pos": torch.randn(10) + 10,
28+
"randn_neg": torch.randn(10) - 10,
29+
"ramp": torch.arange(-16, 16, 0.2),
30+
}
31+
32+
33+
class Sin(torch.nn.Module):
34+
35+
def forward(self, x: torch.Tensor):
36+
return torch.sin(x)
37+
38+
39+
@common.parametrize("test_data", test_data_suite)
40+
def test_sin_tosa_MI(test_data: Tuple):
41+
pipeline = TosaPipelineMI[input_t1](
42+
Sin(),
43+
(test_data,),
44+
aten_op,
45+
exir_op=[],
46+
)
47+
if conftest.get_option("tosa_version") == "1.0":
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", test_data_suite)
52+
def test_sin_tosa_BI(test_data: Tuple):
53+
pipeline = TosaPipelineBI[input_t1](
54+
Sin(),
55+
(test_data,),
56+
aten_op,
57+
exir_op=[],
58+
)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", test_data_suite)
63+
def test_sin_tosa_u55_BI(test_data: Tuple):
64+
pipeline = EthosU55PipelineBI[input_t1](
65+
Sin(),
66+
(test_data,),
67+
aten_op,
68+
exir_ops=[],
69+
run_on_fvp=False,
70+
)
71+
pipeline.run()
72+
73+
74+
@common.parametrize("test_data", test_data_suite)
75+
def test_sin_tosa_u85_BI(test_data: Tuple):
76+
pipeline = EthosU85PipelineBI[input_t1](
77+
Sin(),
78+
(test_data,),
79+
aten_op,
80+
exir_ops=[],
81+
run_on_fvp=False,
82+
)
83+
pipeline.run()

0 commit comments

Comments
 (0)