Commit bf4acb3

Qualcomm AI Engine Direct - Op enablement round, floor, atan (#12298)
### Summary
- Enable round, floor, atan
- Test cases

### Test plan
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperators.test_qnn_backend_round -s $DEVICE_SERIAL -m SM8650 -b build-android/
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperators.test_qnn_backend_round -s $DEVICE_SERIAL -m SM8650 -b build-android/
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperators.test_qnn_backend_floor -s $DEVICE_SERIAL -m SM8650 -b build-android/
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperators.test_qnn_backend_floor -s $DEVICE_SERIAL -m SM8650 -b build-android/
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperators.test_qnn_backend_atan -s $DEVICE_SERIAL -m SM8650 -b build-android/
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperators.test_qnn_backend_atan -s $DEVICE_SERIAL -m SM8650 -b build-android/

Author: @thchenqti
Co-authored-by: thchenqti <[email protected]>
1 parent c228578 commit bf4acb3
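For quick reference, the eager-mode semantics of the three newly delegated ops — a small illustrative snippet, not part of the commit; note that torch.round rounds half to even:

```python
import torch

x = torch.tensor([-1.5, -0.5, 0.5, 1.5])

torch.atan(x)   # elementwise arctangent, in radians
torch.floor(x)  # round toward -inf:    tensor([-2., -1.,  0.,  1.])
torch.round(x)  # round half to even:   tensor([-2., -0.,  0.,  2.])
```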

10 files changed: +269 additions, −4 deletions


backends/qualcomm/_passes/layout_transform.py

Lines changed: 3 additions & 0 deletions

@@ -63,6 +63,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.atan.default,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.bitwise_and.Tensor,
@@ -75,6 +76,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.elu.default,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.exp.default,
+        exir_ops.edge.aten.floor.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.full_like.default,
         exir_ops.edge.aten.ge.Tensor,
@@ -99,6 +101,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,
+        exir_ops.edge.aten.round.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.split_with_sizes.default,
backends/qualcomm/builders/README.md

Lines changed: 4 additions & 4 deletions

@@ -360,7 +360,7 @@ The operator now should be functional for Qualcomm backends. For operator to wor
 ## Operator Support Status
 Please help update following table if you are contributing new operators:

-| Operators | HTP - 77/116 Enabled |
+| Operators | HTP - 80/116 Enabled |
 |-----------|---------|
 | Argmax | &cross; |
 | Argmin | &check; |
@@ -382,14 +382,14 @@ Please help update following table if you are contributing new operators:
 | ElementWiseAdd | &check; |
 | ElementWiseAnd | &check; |
 | ElementWiseAsin | &cross; |
-| ElementWiseAtan | &cross; |
+| ElementWiseAtan | &check; |
 | ElementWiseBinary | &cross; |
 | ElementWiseCeil | &check; |
 | ElementWiseCos | &check; |
 | ElementWiseDivide | &check; |
 | ElementWiseEqual | &check; |
 | ElementWiseExp | &check; |
-| ElementWiseFloor | &cross; |
+| ElementWiseFloor | &check; |
 | ElementWiseFloorDiv | &cross; |
 | ElementWiseGreater | &check; |
 | ElementWiseGreaterEqual | &check; |
@@ -405,7 +405,7 @@ Please help update following table if you are contributing new operators:
 | ElementWiseNotEqual | &check; |
 | ElementWiseOr | &check; |
 | ElementWisePower | &check; |
-| ElementWiseRound | &cross; |
+| ElementWiseRound | &check; |
 | ElementWiseRsqrt | &check; |
 | ElementWiseSelect | &check; |
 | ElementWiseSign | &cross; |
backends/qualcomm/builders/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,7 @@
     op_and,
     op_arange,
     op_argmin,
+    op_atan,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -30,6 +31,7 @@
     op_eq,
     op_exp,
     op_expand,
+    op_floor,
     op_full,
     op_full_like,
     op_gather,
@@ -68,6 +70,7 @@
     op_reshape,
     op_resize,
     op_rms_norm,
+    op_round,
     op_rsqrt,
     op_scalar_tensor,
     op_select_copy,
@@ -103,6 +106,7 @@
     op_and,
     op_arange,
     op_argmin,
+    op_atan,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -120,6 +124,7 @@
     op_eq,
     op_exp,
     op_expand,
+    op_floor,
     op_full,
     op_full_like,
     op_gather,
@@ -158,6 +163,7 @@
     op_reshape,
     op_resize,
     op_rms_norm,
+    op_round,
     op_rsqrt,
     op_scalar_tensor,
     op_select_copy,

backends/qualcomm/builders/op_atan.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpElementWiseAtan, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Atan(NodeVisitor):
+    target = ["aten.atan.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        atan_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseAtan.op_name,
+        )
+        atan_op.AddInputTensors([input_tensor_wrapper])
+        atan_op.AddOutputTensors([output_tensor_wrapper])
+
+        return atan_op

backends/qualcomm/builders/op_floor.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpElementWiseFloor, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Floor(NodeVisitor):
+    target = ["aten.floor.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        floor_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        floor_input_tensors = [floor_inp_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        floor_output_tensors = [output_tensor_wrapper]
+
+        floor_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseFloor.op_name,
+        )
+        floor_op.AddInputTensors(floor_input_tensors)
+        floor_op.AddOutputTensors(floor_output_tensors)
+        return floor_op

backends/qualcomm/builders/op_round.py

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+import warnings
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+
+from .qnn_constants import OpElementWiseRound, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Round(NodeVisitor):
+    target = ["aten.round.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        if len(node.args) > 1:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: QNN does not support decimals",
+                stacklevel=1,
+            )
+            return None
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        round_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseRound.op_name,
+        )
+        round_op.AddInputTensors([input_tensor_wrapper])
+        round_op.AddOutputTensors([output_tensor_wrapper])
+        return round_op
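Note the guard above: QNN has no equivalent of the decimals argument, so only aten.round.default gets a QNN op built for it. A hedged sketch of the distinction (module names here are illustrative, not from the commit):

```python
import torch

class RoundDefault(torch.nn.Module):
    def forward(self, x):
        # aten.round.default: handled by the builder above.
        return torch.round(x)

class RoundDecimals(torch.nn.Module):
    def forward(self, x):
        # The decimals overload carries an extra argument; the builder
        # warns and returns None, so this op is expected to stay on the
        # reference (CPU) implementation rather than the QNN delegate.
        return torch.round(x, decimals=2)
```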

backends/qualcomm/builders/qnn_constants.py

Lines changed: 15 additions & 0 deletions

@@ -105,6 +105,11 @@ class OpElementWiseAnd:
     op_name: str = "ElementWiseAnd"


+@dataclass(init=False, frozen=True)
+class OpElementWiseAtan:
+    op_name: str = "ElementWiseAtan"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseCeil:
     op_name = "ElementWiseCeil"
@@ -130,6 +135,11 @@ class OpElementWiseEqual:
     op_name: str = "ElementWiseEqual"


+@dataclass(init=False, frozen=True)
+class OpElementWiseFloor:
+    op_name: str = "ElementWiseFloor"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseGreater:
     op_name: str = "ElementWiseGreater"
@@ -203,6 +213,11 @@ class OpElementWisePower:
     op_name: str = "ElementWisePower"


+@dataclass(init=False, frozen=True)
+class OpElementWiseRound:
+    op_name: str = "ElementWiseRound"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseRsqrt:
     op_name: str = "ElementWiseRsqrt"

backends/qualcomm/quantizer/annotators.py

Lines changed: 15 additions & 0 deletions

@@ -163,6 +163,11 @@ def annotate_single_in_single_out(
     )


+@register_annotator([torch.ops.aten.atan.default])
+def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.topk.default])
 def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
@@ -404,6 +409,11 @@ def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.floor.default])
+def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
 def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
@@ -414,6 +424,11 @@ def annotate_repeat(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.round.default])
+def annotate_round(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.cos.default])
 def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
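All three new annotators delegate to annotate_single_in_single_out, the shared helper for unary elementwise ops. A simplified sketch of what that helper does, assumed from the surrounding file and the standard torch.ao quantizer machinery (not the verbatim helper, which is defined near the top of annotators.py):

```python
from torch.ao.quantization.quantizer import QuantizationAnnotation

def annotate_single_in_single_out_sketch(node, quantization_config):
    # _is_annotated is the helper used throughout annotators.py.
    if _is_annotated([node]):
        return
    # Attach input/output quantization specs to the FX node's metadata.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={node.args[0]: quantization_config.input_activation},
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
```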

backends/qualcomm/tests/models.py

Lines changed: 24 additions & 0 deletions

@@ -146,6 +146,14 @@ def forward(self, x, y):
         return squeeze_out, conv_out


+class Atan(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.atan(x)
+
+
 class AvgPoolModule(torch.nn.Module):
     def __init__(self, kernel_size, stride, padding, ceil_mode):
         super().__init__()
@@ -741,6 +749,14 @@ def forward(self, x):
         return torch.special.expm1(x)


+class Floor(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.floor(x)
+
+
 class Fold(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1449,6 +1465,14 @@ def forward(self, x):
         return torch.roll(x, shifts=self.shifts, dims=self.dims)


+class Round(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.round(x)
+
+
 class Rsqrt(torch.nn.Module):
     def __init__(self):
         super().__init__()
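Each new module is then exercised by a matching test in backends/qualcomm/tests/test_qnn_delegate.py, presumably following the file's existing pattern; a sketch (the sample input shape is an assumption, not the committed value):

```python
def test_qnn_backend_atan(self):
    module = Atan()  # defined in backends/qualcomm/tests/models.py
    sample_input = (torch.randn(2, 5, 1, 3),)  # illustrative shape
    self.lower_module_and_test_output(module, sample_input)
```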
