Skip to content

Commit 7c85dc4

Browse files
authored
Qualcomm AI Engine Direct - Support Bitwise Or (#9224)
### Summary - Support BitWiseOr operator ### Test plan - Added UT
1 parent c066ee7 commit 7c85dc4

File tree

7 files changed

+149
-0
lines changed

7 files changed

+149
-0
lines changed

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
4747
layout_agnostic_ops = {
4848
exir_ops.edge.aten.abs.default,
4949
exir_ops.edge.aten.add.Tensor,
50+
exir_ops.edge.aten.bitwise_or.Tensor,
5051
exir_ops.edge.aten.bmm.default,
5152
exir_ops.edge.aten.cat.default,
5253
exir_ops.edge.aten.ceil.default,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
op_mul,
5353
op_ne,
5454
op_neg,
55+
op_or,
5556
op_pad,
5657
op_pow,
5758
op_prelu,
@@ -131,6 +132,7 @@
131132
op_mul,
132133
op_neg,
133134
op_ne,
135+
op_or,
134136
op_pad,
135137
op_pow,
136138
op_prelu,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import torch
11+
12+
from .node_visitor import NodeVisitor, register_node_visitor
13+
from .qnn_constants import OpElementWiseOr, QNN_OP_PACKAGE_NAME_QTI_AISW
14+
15+
16+
@register_node_visitor
17+
class OpOr(NodeVisitor):
18+
target = ["aten.bitwise_or.Tensor"]
19+
20+
def __init__(self, *args) -> None:
21+
super().__init__(*args)
22+
23+
def define_node(
24+
self,
25+
node: torch.fx.Node,
26+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
27+
) -> PyQnnWrapper.PyQnnOpWrapper:
28+
out_tensor = self.get_tensor(node, node)
29+
output_tensor_wrapper = self.define_tensor(
30+
node,
31+
node,
32+
out_tensor,
33+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
34+
nodes_to_wrappers,
35+
)
36+
or_output_tensors = [output_tensor_wrapper]
37+
38+
or_input_tensors = []
39+
for index in range(2):
40+
input_node = node.args[index]
41+
input_tensor = self.get_tensor(input_node, node)
42+
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
43+
44+
input_tensor_wrapper = self.define_tensor(
45+
input_node,
46+
node,
47+
input_tensor,
48+
tensor_type,
49+
nodes_to_wrappers,
50+
)
51+
or_input_tensors.append(input_tensor_wrapper)
52+
or_op = PyQnnWrapper.PyQnnOpWrapper(
53+
node.name,
54+
QNN_OP_PACKAGE_NAME_QTI_AISW,
55+
OpElementWiseOr.op_name,
56+
)
57+
or_op.AddInputTensors(or_input_tensors)
58+
or_op.AddOutputTensors(or_output_tensors)
59+
return or_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ class OpElementWiseNotEqual:
168168
op_name: str = "ElementWiseNotEqual"
169169

170170

171+
@dataclass(init=False, frozen=True)
172+
class OpElementWiseOr:
173+
op_name: str = "ElementWiseOr"
174+
175+
171176
@dataclass(init=False, frozen=True)
172177
class OpElementWisePower:
173178
op_name: str = "ElementWisePower"

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,11 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
680680
)
681681

682682

683+
@register_annotator([torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.__or__.Tensor])
684+
def annotate_bitwise_or(node: Node, quantization_config: QuantizationConfig) -> None:
685+
annotate_binary(node, quantization_config)
686+
687+
683688
@register_annotator([torch.ops.aten.pow.Tensor_Tensor])
684689
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
685690
annotate_single_in_single_out(node, quantization_config)

backends/qualcomm/tests/models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,28 @@ def forward(self, x):
10251025
return x != self.constant
10261026

10271027

1028+
class OrBitWise(torch.nn.Module):
1029+
def __init__(self, pos, neg):
1030+
super().__init__()
1031+
self.pos = pos
1032+
self.neg = neg
1033+
1034+
def forward(self, x, y):
1035+
bitwise_or = torch.bitwise_or(x, y).bool()
1036+
return torch.where(bitwise_or, self.pos, self.neg)
1037+
1038+
1039+
class OrOperator(torch.nn.Module):
1040+
def __init__(self, pos, neg):
1041+
super().__init__()
1042+
self.pos = pos
1043+
self.neg = neg
1044+
1045+
def forward(self, x, y):
1046+
operator_or = x.to(torch.bool) | y.to(torch.bool)
1047+
return torch.where(operator_or, self.pos, self.neg)
1048+
1049+
10281050
class Pad(torch.nn.Module):
10291051
def __init__(self):
10301052
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,33 @@ def test_qnn_backend_element_wise_mul(self):
310310
self.lower_module_and_test_output(module, sample_input)
311311
index += 1
312312

313+
def test_qnn_backend_element_wise_or(self):
314+
test_comb = [
315+
{
316+
QCOM_MODULE: OrBitWise( # noqa: F405
317+
torch.tensor(1.7), torch.tensor(0.2)
318+
),
319+
QCOM_SAMPLE_INPUTS: (
320+
torch.tensor([1, 0, 1, 0], dtype=torch.bool),
321+
torch.tensor([1, 1, 0, 0], dtype=torch.bool),
322+
),
323+
},
324+
{
325+
QCOM_MODULE: OrOperator( # noqa: F405
326+
torch.tensor(1.5), torch.tensor(-1.2)
327+
),
328+
QCOM_SAMPLE_INPUTS: (
329+
torch.full((3, 3), 1).triu(),
330+
torch.full((3, 3), 1).tril(diagonal=0),
331+
),
332+
},
333+
]
334+
for i, test in enumerate(test_comb):
335+
with self.subTest(i=i):
336+
self.lower_module_and_test_output(
337+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
338+
)
339+
313340
def test_qnn_backend_element_wise_sqrt(self):
314341
modules = [Sqrt(), SqrtConstant()] # noqa: F405
315342
for i, module in enumerate(modules):
@@ -1246,6 +1273,34 @@ def test_qnn_backend_element_wise_mul(self):
12461273
self.lower_module_and_test_output(module, sample_input)
12471274
index += 1
12481275

1276+
def test_qnn_backend_element_wise_or(self):
1277+
test_comb = [
1278+
{
1279+
QCOM_MODULE: OrBitWise( # noqa: F405
1280+
torch.tensor(1.7), torch.tensor(0.2)
1281+
),
1282+
QCOM_SAMPLE_INPUTS: (
1283+
torch.tensor([1, 0, 1, 0], dtype=torch.bool),
1284+
torch.tensor([1, 1, 0, 0], dtype=torch.bool),
1285+
),
1286+
},
1287+
{
1288+
QCOM_MODULE: OrOperator( # noqa: F405
1289+
torch.tensor(1.5), torch.tensor(-1.2)
1290+
),
1291+
QCOM_SAMPLE_INPUTS: (
1292+
torch.full((3, 3), 1).triu(),
1293+
torch.full((3, 3), 1).tril(diagonal=0),
1294+
),
1295+
},
1296+
]
1297+
for i, test in enumerate(test_comb):
1298+
with self.subTest(i=i):
1299+
module = self.get_qdq_module(
1300+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
1301+
)
1302+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
1303+
12491304
def test_qnn_backend_element_wise_sqrt(self):
12501305
modules = [Sqrt(), SqrtConstant()] # noqa: F405
12511306
for i, module in enumerate(modules):

0 commit comments

Comments
 (0)