Skip to content

Commit a90f1aa

Browse files
committed
Support bitwise and, xor, and or ops in Arm backend
Ops are very similar and thus clumped together. No quantization since that doesn't make sense for bitwise ops. Add a factory for creating simple two input NodeVisitors. This can be extended for future such ops. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Ic3615067cd03b8775f4b028f601b6b2366487c9d
1 parent a70c6a3 commit a90f1aa

File tree

5 files changed

+281
-0
lines changed

5 files changed

+281
-0
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# pyre-unsafe
77

88
from . import ( # noqa
9+
bitwise_support,
910
convolution_support,
1011
pool_2d_support,
1112
reduce_sum_support,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch.fx as fx
7+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
8+
register_tosa_support_check,
9+
SupportedTOSAOperatorCheck,
10+
)
11+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
@register_tosa_support_check
class BitwiseSupported(SupportedTOSAOperatorCheck):
    """Support check for the tensor variants of bitwise and, or and xor."""

    # Edge-dialect operators covered by this check.
    targets = [
        exir_ops.edge.aten.bitwise_and.Tensor,
        exir_ops.edge.aten.bitwise_or.Tensor,
        exir_ops.edge.aten.bitwise_xor.Tensor,
    ]

    # TOSA specifications this check is declared for.
    tosa_specs = [
        TosaSpecification.create_from_string(spec_string)
        for spec_string in ("TOSA-0.80+BI", "TOSA-0.80+MI")
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return False when ``tosa_spec`` is a TOSA 0.80 U55 subset, else True."""
        # U55 case, Vela 4.2.0 (25.02 release)
        is_u55 = isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset
        return not is_u55

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,5 @@
4444
op_transpose,
4545
op_upsample_nearest2d,
4646
op_view,
47+
ops_binary,
4748
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts
11+
import torch
12+
import torch.fx
13+
14+
from executorch.backends.arm.operators.node_visitor import (
15+
NodeVisitor,
16+
register_node_visitor,
17+
)
18+
from executorch.backends.arm.tosa_mapping import TosaArg
19+
from serializer.tosa_serializer import TosaOp
20+
21+
22+
def binary_operator_factory(bw_target: str, tosa_op):
    """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.

    Args:
        bw_target: Target string assigned to the generated visitor's
            ``target`` attribute, e.g. ``"aten.bitwise_and.Tensor"``.
        tosa_op: TOSA operator passed to ``addOperator`` when the visitor
            defines a node.
    """

    class BinaryOperator(NodeVisitor):
        target = bw_target

        def define_node(
            self,
            node: torch.fx.Node,
            tosa_graph: ts.TosaSerializer,
            inputs: List[TosaArg],
            output: TosaArg,
        ) -> None:
            """Emit ``tosa_op`` taking the two inputs and producing ``output``.

            Raises:
                ValueError: If both inputs and the output do not share one dtype.
            """
            if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
                # Adjacent literals are concatenated verbatim, so the first one
                # must end with a space to keep the two sentences apart.
                raise ValueError(
                    "All inputs and outputs need same dtype. "
                    f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}."
                )

            tosa_graph.addOperator(
                tosa_op, [inputs[0].name, inputs[1].name], [output.name]
            )

    register_node_visitor(BinaryOperator)
47+
48+
49+
# Register one visitor per bitwise op; each maps directly to its TOSA op.
binary_operator_factory(
    bw_target="aten.bitwise_and.Tensor", tosa_op=TosaOp.Op().BITWISE_AND
)
binary_operator_factory(
    bw_target="aten.bitwise_xor.Tensor", tosa_op=TosaOp.Op().BITWISE_XOR
)
binary_operator_factory(
    bw_target="aten.bitwise_or.Tensor", tosa_op=TosaOp.Op().BITWISE_OR
)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
from typing import Callable, NamedTuple, Tuple
9+
10+
import torch
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
from parameterized import parameterized
14+
15+
16+
class DataTuple(NamedTuple):
    """A named pair of input tensors for one bitwise test case."""

    # Label used when composing the parameterized test name.
    name: str
    # First operand.
    tensor1: torch.Tensor
    # Second operand.
    tensor2: torch.Tensor
20+
21+
22+
class OpTuple(NamedTuple):
    """A named module under test."""

    # Label used when composing the parameterized test name.
    name: str
    # Module whose forward applies the operator being tested.
    operator: torch.nn.Module
25+
26+
27+
class And(torch.nn.Module):
    """Module computing the elementwise bitwise AND of its two inputs."""

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        return torch.bitwise_and(tensor1, tensor2)
30+
31+
32+
class Xor(torch.nn.Module):
    """Module computing the elementwise bitwise XOR of its two inputs."""

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        return torch.bitwise_xor(tensor1, tensor2)
35+
36+
37+
class Or(torch.nn.Module):
    """Module computing the elementwise bitwise OR of its two inputs."""

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        return torch.bitwise_or(tensor1, tensor2)
40+
41+
42+
# Input pairs shared by every bitwise test: int32 and int8 tensors of rank 2 to 4.
test_data_suite: list[DataTuple] = [
    DataTuple(
        "zeros",
        torch.zeros(1, 10, 10, 10, dtype=torch.int32),
        torch.zeros(1, 10, 10, 10, dtype=torch.int32),
    ),
    DataTuple(
        "ones",
        torch.ones(10, 10, 10, dtype=torch.int8),
        torch.ones(10, 10, 10, dtype=torch.int8),
    ),
    DataTuple(
        "rand_rank2",
        torch.randint(-128, 127, (10, 10), dtype=torch.int8),
        torch.randint(-128, 127, (10, 10), dtype=torch.int8),
    ),
    DataTuple(
        "rand_rank4",
        # The upper bound is exclusive: the previous bound of -127 made every
        # element -128, so this operand was not random at all.
        torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
        torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
    ),
]


# Modules under test, paired with the label used in test names.
ops: list[OpTuple] = [
    OpTuple("and", And()),
    OpTuple("or", Or()),
    OpTuple("xor", Xor()),
]

# One parameterized case per (operator, data) pair, named "<op>_<data>".
# A comprehension keeps the loop variables out of the module namespace.
full_test_suite = [
    (
        f"{op.name}_{test_data.name}",
        op.operator,
        test_data.tensor1,
        test_data.tensor2,
    )
    for op in ops
    for test_data in test_data_suite
]

del ops
86+
87+
88+
class TestBitwise(unittest.TestCase):
    """Lowering tests for bitwise and/or/xor with TOSA MI/BI, U55 and U85 compile specs."""

    # Key counted in the lowered graph to tell whether the op was delegated.
    _DELEGATE_CALL = "torch.ops.higher_order.executorch_call_delegate"

    def _run_delegated_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, torch.Tensor],
        compile_spec,
    ) -> None:
        """Export and lower ``module``, check for one delegate call, then compare outputs."""
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .export()
            .to_edge_transform_and_lower()
            .check_count({self._DELEGATE_CALL: 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_bitwise_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the delegated pipeline with the "TOSA-0.80+MI" compile spec."""
        self._run_delegated_pipeline(
            module, test_data, common.get_tosa_compile_spec("TOSA-0.80+MI")
        )

    def _test_bitwise_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the delegated pipeline with the "TOSA-0.80+BI" compile spec."""
        self._run_delegated_pipeline(
            module,
            test_data,
            common.get_tosa_compile_spec(
                "TOSA-0.80+BI", custom_path="local_bin/bitwise"
            ),
        )

    def _test_bitwise_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Lower with the U55 compile spec and check that nothing is delegated."""
        # Tests that we don't delegate these ops since they are not supported on U55.
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .export()
            .to_edge_transform_and_lower()
            .check_count({self._DELEGATE_CALL: 0})
        )

    def _test_bitwise_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Lower and serialize with the U85 compile spec, expecting one delegate call."""
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u85_compile_spec(),
            )
            .export()
            .to_edge_transform_and_lower()
            .check_count({self._DELEGATE_CALL: 1})
            .to_executorch()
            .serialize()
        )
        # Outputs are only compared when the "corstone_fvp" option is enabled.
        if conftest.is_option_enabled("corstone_fvp"):
            tester.run_method_and_compare_outputs(inputs=test_data)

    @parameterized.expand(full_test_suite)
    def test_tosa_MI(
        self,
        test_name: str,
        operator: Callable,
        tensor1: torch.Tensor,
        tensor2: torch.Tensor,
    ):
        self._test_bitwise_tosa_MI_pipeline(operator, (tensor1, tensor2))

    @parameterized.expand(full_test_suite)
    def test_tosa_BI(
        self,
        test_name: str,
        operator: Callable,
        tensor1: torch.Tensor,
        tensor2: torch.Tensor,
    ):
        self._test_bitwise_tosa_BI_pipeline(operator, (tensor1, tensor2))

    @parameterized.expand(full_test_suite)
    def test_tosa_u55_BI(
        self,
        test_name: str,
        operator: Callable,
        tensor1: torch.Tensor,
        tensor2: torch.Tensor,
    ):
        self._test_bitwise_tosa_u55_BI_pipeline(operator, (tensor1, tensor2))

    @parameterized.expand(full_test_suite)
    def test_tosa_u85_BI(
        self,
        test_name: str,
        operator: Callable,
        tensor1: torch.Tensor,
        tensor2: torch.Tensor,
    ):
        self._test_bitwise_tosa_u85_BI_pipeline(operator, (tensor1, tensor2))

0 commit comments

Comments
 (0)