Skip to content

Commit f8dd1b8

Browse files
committed
Qualcomm AI Engine Direct - Support topk
1 parent a91eb8a commit f8dd1b8

19 files changed

+418
-26
lines changed

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
op_sum_int_list,
5454
op_tanh,
5555
op_to,
56+
op_topk,
5657
op_transpose,
5758
op_unsqueeze,
5859
op_upsample_bilinear2d,
@@ -107,6 +108,7 @@
107108
op_sub,
108109
op_sum_int_list,
109110
op_tanh,
111+
op_topk,
110112
op_to,
111113
op_transpose,
112114
op_unsqueeze,

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from typing import cast, Dict, List
78

89
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -85,7 +86,10 @@ def define_node(
8586
if len(node.args) > 6:
8687
divisor_override = cast(int, node.args[6])
8788
if divisor_override != pooling_region:
88-
print("Not support divisor_override which is not equal to pooling region.")
89+
warnings.warn(
90+
"[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.",
91+
stacklevel=1,
92+
)
8993
return
9094

9195
avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(

backends/qualcomm/builders/op_cat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from typing import cast, Dict, List
78

89
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -43,8 +44,9 @@ def define_node(
4344
)
4445

4546
if len(list_of_tensors) != len(list_of_tensor_wrappers):
46-
print(
47-
"The number or input tensors is not equal to the number of input tensor wrappers."
47+
warnings.warn(
48+
"[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.",
49+
stacklevel=1,
4850
)
4951
return
5052

backends/qualcomm/builders/op_conv2d.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
78
from typing import cast, Dict, List
89

910
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -189,12 +190,18 @@ def _define_conv1d(
189190

190191
# args[6] = transposed
191192
if cast(bool, node.args[6]):
192-
print("Currently, No support for transposed convolution")
193+
warnings.warn(
194+
"[QNN Delegate Op Builder]: Currently, No support for transposed convolution.",
195+
stacklevel=1,
196+
)
193197
return
194198

195199
# args[7] = output padding
196200
if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
197-
print("QNN does not support output padding")
201+
warnings.warn(
202+
"[QNN Delegate Op Builder]: QNN does not support output padding.",
203+
stacklevel=1,
204+
)
198205
return
199206

200207
stride_shape = [len(stride)]

backends/qualcomm/builders/op_expand.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from typing import cast, Dict, List
78

89
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -52,8 +53,9 @@ def define_node(
5253
output_dims = len(output_tensor.size())
5354

5455
if input_dims < output_dims:
55-
print(
56-
f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}."
56+
warnings.warn(
57+
f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.",
58+
stacklevel=1,
5759
)
5860
return
5961

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
78
from typing import Dict
89

910
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -44,7 +45,10 @@ def define_node(
4445
len(normalized_shapes) != 1
4546
and normalized_shapes[0] != input_tensor.shape[-1]
4647
):
47-
print("Only supports normalization with last input dimension")
48+
warnings.warn(
49+
"[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
50+
stacklevel=1,
51+
)
4852
return
4953
axis = [len(input_tensor.shape) - 1]
5054
axis_shape = [len(axis)]

backends/qualcomm/builders/op_linear.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
78
from typing import Dict
89

910
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,8 +71,9 @@ def define_node(
7071

7172
# TODO remove this when qnn sdk support
7273
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
73-
print(
74-
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
74+
warnings.warn(
75+
f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.",
76+
stacklevel=1,
7577
)
7678
bias_tensor = get_parameter(bias_node, self.edge_program)
7779
bias_tensor_wrapper = self.define_tensor(

backends/qualcomm/builders/op_max_pool2d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from typing import cast, Dict, List
78

89
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -42,8 +43,9 @@ def define_node(
4243
if user.target.__name__ == "getitem":
4344
getitem_index = user.args[1]
4445
if getitem_index != 0:
45-
print(
46-
f"Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}"
46+
warnings.warn(
47+
f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}",
48+
stacklevel=1,
4749
)
4850
return
4951

@@ -78,8 +80,9 @@ def define_node(
7880
if len(node.args) > 4:
7981
dilation = cast(List[int], node.args[4])
8082
if not (dilation == 1 or dilation == [1, 1]):
81-
print(
82-
f"Not support dilation argument for max pool2d, but got {dilation}"
83+
warnings.warn(
84+
f"[QNN Delegate Op Builder]: Not support dilation argument for max pool2d, but got {dilation}",
85+
stacklevel=1,
8386
)
8487
return
8588

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
78
from typing import Dict
89

910
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -47,7 +48,10 @@ def define_node(
4748
len(normalized_shapes) != 1
4849
and normalized_shapes[0] != input_tensor.shape[-1]
4950
):
50-
print("Only supports normalization with last input dimension")
51+
warnings.warn(
52+
"[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
53+
stacklevel=1,
54+
)
5155
return
5256
axes = [node.args[0].meta["val"].dim() - 1]
5357
axes_shape = [len(axes)]
backends/qualcomm/builders/op_topk.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import warnings
7+
from typing import cast, Dict
8+
9+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10+
11+
import numpy as np
12+
import torch
13+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
14+
15+
from .node_visitor import NodeVisitor, register_node_visitor
16+
from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
class TopK(NodeVisitor):
    """Node visitor that lowers ``aten.topk.default`` to a QNN TopK op."""

    target = ["aten.topk.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for a ``topk`` node.

        Args:
            node: FX node whose positional args are read as
                ``(input, k[, dim[, largest]])``.
            nodes_to_wrappers: Mapping handed through to ``define_tensor``
                for every tensor wrapper created here.

        Returns:
            The configured TopK op wrapper. When an explicit ``dim`` is
            given that does not resolve to the last axis, a warning is
            issued and ``None`` is returned instead.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        # NOTE(review): the input is declared as QNN_TENSOR_TYPE_STATIC;
        # confirm this is intended for a non-constant input tensor.
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        k = cast(int, node.args[1])

        # The dimension is only validated when it is passed explicitly.
        if len(node.args) > 2:
            rank = len(input_tensor.shape)
            dim = cast(int, node.args[2])
            if dim < 0:
                # Normalize a negative axis to its non-negative equivalent.
                dim = dim % rank
            if QCOM_AXIS_ORDER in node.meta:
                # Translate the axis through the recorded axis order.
                dim = node.meta[QCOM_AXIS_ORDER].index(dim)
            if dim != rank - 1:
                warnings.warn(
                    "[QNN Delegate Op Builder]: QNN currently only supports channel as dimension for topK.",
                    stacklevel=1,
                )
                return

        topk_input_tensors = [input_tensor_wrapper]

        output_val_tensor = self.get_tensor(node, node, 0)
        output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32)

        # QNN constraint, topk output_0 requires having the same quant config as input
        node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
        output_val_tensor_wrapper = self.define_tensor(
            node,
            output_val_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # topk output_1 is index, do not quantize it. The quant attrs are
        # removed from the node before the index tensor is defined.
        node.meta.pop("quant_attrs", None)
        output_index_tensor_wrapper = self.define_tensor(
            node,
            output_idx_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
            wrapper_idx=1,
        )
        topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]

        topk_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTopK.op_name,
        )
        topk_op.AddInputTensors(topk_input_tensors)
        topk_op.AddOutputTensors(topk_output_tensors)

        # Use the shared QCOM_DATA key for both scalar params so the two
        # AddScalarParam calls stay consistent with each other.
        topk_op.AddScalarParam(
            OpTopK.param_k,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(k)},
        )

        # As of QNN 2.26, QNN HTP backend only allows users to set this value to 1, or else it will fail at op validation
        if len(node.args) > 3:
            largest = cast(bool, node.args[3])
            topk_op.AddScalarParam(
                OpTopK.param_largest,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: largest},
            )

        return topk_op

0 commit comments

Comments
 (0)