Skip to content

Commit 2845fd3

Browse files
authored
Arm Backend: Split ops_unary into op_ceil, op_floor and op_logical_not (#14059)
Add op_ceil.py, op_floor.py, and op_logical_not.py instead of implementing the ceil, floor, and logical_not operators in ops_unary.py Signed-off-by: Agrima Khare <[email protected]>
1 parent 90153fa commit 2845fd3

File tree

5 files changed

+170
-69
lines changed

5 files changed

+170
-69
lines changed

backends/arm/operators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515
op_avg_pool2d,
1616
op_bmm,
1717
op_cat,
18+
op_ceil,
1819
op_clamp,
1920
op_constant_pad_nd,
2021
op_conv2d,
2122
op_cos,
2223
op_eq,
2324
op_erf,
2425
op_exp,
26+
op_floor,
2527
op_ge,
2628
op_gt,
2729
op_index_select,
2830
op_index_tensor,
2931
op_le,
3032
op_log,
33+
op_logical_not,
3134
op_lt,
3235
op_max_pool2d,
3336
op_maximum,
@@ -57,5 +60,4 @@
5760
op_where,
5861
ops_binary,
5962
ops_identity,
60-
ops_unary,
6163
)

backends/arm/operators/op_ceil.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch.fx
9+
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.operators.operator_validation_utils import (
15+
validate_num_inputs,
16+
validate_same_dtype,
17+
validate_valid_dtype,
18+
)
19+
from executorch.backends.arm.tosa import TosaSpecification
20+
21+
from executorch.backends.arm.tosa.mapping import TosaArg
22+
23+
24+
@register_node_visitor
25+
class CeilVisitor(NodeVisitor):
26+
target = "aten.ceil.default"
27+
28+
# INT case should be handled by op_table
29+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
30+
31+
def __init__(self, *args):
32+
super().__init__(*args)
33+
34+
def define_node(
35+
self,
36+
node: torch.fx.Node,
37+
tosa_graph: Any,
38+
inputs: List[TosaArg],
39+
output: TosaArg,
40+
) -> None:
41+
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
42+
43+
validate_num_inputs(self.target, inputs, 1)
44+
validate_same_dtype(self.target, [*inputs, output], ts)
45+
validate_valid_dtype(
46+
self.target,
47+
inputs[0],
48+
ts.DType.FP32,
49+
output.tosa_spec,
50+
)
51+
52+
self._serialize_operator(
53+
node, tosa_graph, ts.TosaOp.Op().CEIL, [inputs[0].name], [output.name]
54+
)

backends/arm/operators/op_floor.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch.fx
9+
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.operators.operator_validation_utils import (
15+
validate_num_inputs,
16+
validate_same_dtype,
17+
validate_valid_dtype,
18+
)
19+
from executorch.backends.arm.tosa import TosaSpecification
20+
21+
from executorch.backends.arm.tosa.mapping import TosaArg
22+
23+
24+
@register_node_visitor
25+
class FloorVisitor(NodeVisitor):
26+
target = "aten.floor.default"
27+
28+
# INT case should be handled by op_table
29+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
30+
31+
def __init__(self, *args):
32+
super().__init__(*args)
33+
34+
def define_node(
35+
self,
36+
node: torch.fx.Node,
37+
tosa_graph: Any,
38+
inputs: List[TosaArg],
39+
output: TosaArg,
40+
) -> None:
41+
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
42+
43+
validate_num_inputs(self.target, inputs, 1)
44+
validate_same_dtype(self.target, [*inputs, output], ts)
45+
validate_valid_dtype(
46+
self.target,
47+
inputs[0],
48+
ts.DType.FP32,
49+
output.tosa_spec,
50+
)
51+
52+
self._serialize_operator(
53+
node, tosa_graph, ts.TosaOp.Op().FLOOR, [inputs[0].name], [output.name]
54+
)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch.fx
9+
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.operators.operator_validation_utils import (
15+
validate_num_inputs,
16+
validate_same_dtype,
17+
validate_valid_dtype,
18+
)
19+
from executorch.backends.arm.tosa import TosaSpecification
20+
from executorch.backends.arm.tosa.mapping import TosaArg
21+
22+
23+
@register_node_visitor
24+
class LogicalNotVisitor(NodeVisitor):
25+
target = "aten.logical_not.default"
26+
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
30+
]
31+
32+
def __init__(self, *args):
33+
super().__init__(*args)
34+
35+
def define_node(
36+
self,
37+
node: torch.fx.Node,
38+
tosa_graph: Any,
39+
inputs: List[TosaArg],
40+
output: TosaArg,
41+
) -> None:
42+
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
43+
44+
validate_num_inputs(self.target, inputs, 1)
45+
validate_same_dtype(self.target, [*inputs, output], ts)
46+
validate_valid_dtype(
47+
self.target,
48+
[*inputs, output],
49+
[ts.DType.BOOL],
50+
output.tosa_spec,
51+
)
52+
53+
self._serialize_operator(
54+
node,
55+
tosa_graph,
56+
ts.TosaOp.Op().LOGICAL_NOT,
57+
[inputs[0].name],
58+
[output.name],
59+
)

backends/arm/operators/ops_unary.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments
 (0)