
Commit d094b09

winskuo-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - Support topk (pytorch#5870)
Summary:
- Support topK
- Properly decompose einsum for quantization annotation to work properly
- Unify warning messages in op builder
- Add UT

Pull Request resolved: pytorch#5870
Reviewed By: kirklandsign
Differential Revision: D63947001
Pulled By: cccclai
fbshipit-source-id: cd0e81abea48a2472a4791263407ecd17f91e906
1 parent 4bddf01 commit d094b09

19 files changed (+417, -26 lines)
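
For orientation, a minimal sketch (not part of this commit; module and shapes are illustrative) of the kind of graph this change targets, combining the two ops the PR touches: einsum, which the new pass decomposes before quantization annotation, and topk, which becomes a delegated op:

```python
import torch


class TopkEinsumModule(torch.nn.Module):
    """Illustrative module exercising the two ops this PR touches."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        attn = torch.einsum("bij,bjk->bik", x, y)  # decomposed by DecomposeEinsum
        values, _ = torch.topk(attn, k=3, dim=-1)  # now supported by the QNN delegate
        return values


# Quick eager sanity check of the graph the backend would receive.
m = TopkEinsumModule()
print(m(torch.randn(2, 4, 5), torch.randn(2, 5, 6)).shape)  # torch.Size([2, 4, 3])
```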
backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 65 additions & 0 deletions (new file)
```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
    """
    Decompose einsum for quantization annotation to work properly.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.einsum.default:
                decomposed_module = make_fx(
                    node.target,
                    tracing_mode="fake",
                )(node.args[0], [arg.meta["val"] for arg in node.args[1]])

                with graph.inserting_before(node):
                    # remap is used to map original node values to new node values,
                    # which ensures that references to nodes are correctly updated
                    # in the new graph
                    remap = {}
                    # Different from other nodes, einsum args[0] is the einsum
                    # equation, while input nodes are stored in args[1]
                    for i, arg in enumerate(node.args[1]):
                        remap[f"arg1_{i+1}"] = arg

                    for decomposed_node in decomposed_module.graph.nodes:
                        # This is the args[0] equation string, which is no longer
                        # required after decomposition
                        if "arg0" in decomposed_node.name:
                            continue

                        # no need to copy the existing 'output' node
                        if decomposed_node.op == "output":
                            for user in node.users.copy():
                                # remap
                                user.replace_input_with(
                                    node,
                                    remap[decomposed_node.args[0][0]],
                                )
                        # no need to copy existing placeholders
                        elif decomposed_node.op == "placeholder":
                            # replace node map key from string to graph node
                            remap[decomposed_node] = remap.pop(decomposed_node.name)
                        else:
                            remap[decomposed_node] = graph.node_copy(
                                decomposed_node,
                                arg_transform=lambda x, remap=remap: remap[x],
                            )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
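
A hedged usage sketch for the pass above (not from this commit: the import path assumes the new file lives at backends/qualcomm/_passes/decompose_einsum.py, and the capture is assumed to preserve aten.einsum.default, as the pre-autograd capture used by this backend does):

```python
import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum


class Einsum(torch.nn.Module):
    def forward(self, x, y):
        return torch.einsum("ij,jk->ik", x, y)


# Capture a graph containing torch.ops.aten.einsum.default, then decompose it
# so quantization annotation sees primitive matmul/view nodes instead of one
# opaque einsum node.
ep = torch.export.export(Einsum(), (torch.randn(2, 3), torch.randn(3, 4)))
result = DecomposeEinsum().call(ep.graph_module)
assert result.modified
print(result.graph_module.graph)
```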

backends/qualcomm/_passes/insert_requantize.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@ class InsertRequantize(ExportPass):
     # we don't use the 2nd output, 2nd output is an integer, etc.
     multi_output_op_ignore_set = {
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+        exir_ops.edge.aten.topk.default,
     }
 
     def __init__(
```
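
The new ignore-set entry follows from topk's output signature: only the first output is a float tensor that requantization applies to; the second is integer indices, matching the comment above about unused non-float outputs. A quick eager check:

```python
import torch

values, indices = torch.topk(torch.randn(1, 10), k=3)
print(values.dtype)   # torch.float32 -- quantizable, may need requantize
print(indices.dtype)  # torch.int64  -- indices, nothing to requantize
```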

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -65,6 +65,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
         exir_ops.edge.aten.split_with_sizes.default,
         *q_ops,
```
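
Adding topk here marks it as safe for the layout pass to move axis-order transforms across (judging from the surrounding entries): topk selects along a single dimension and preserves rank, so it commutes with a permute as long as the dim index is remapped. An eager illustration of that property (shapes are illustrative):

```python
import torch

x = torch.randn(2, 3, 4, 5)                   # NCHW
nhwc = x.permute(0, 2, 3, 1)                  # NHWC view of the same data
a, _ = torch.topk(x, k=2, dim=1)              # top-k over channels, NCHW
b, _ = torch.topk(nhwc, k=2, dim=3)           # top-k over channels, NHWC
print(torch.equal(a.permute(0, 2, 3, 1), b))  # True: layout commutes with topk
```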

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -53,6 +53,7 @@
     op_sum_int_list,
     op_tanh,
     op_to,
+    op_topk,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -107,6 +108,7 @@
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_topk,
     op_to,
     op_transpose,
     op_unsqueeze,
```

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -85,7 +86,10 @@ def define_node(
         if len(node.args) > 6:
             divisor_override = cast(int, node.args[6])
             if divisor_override != pooling_region:
-                print("Not support divisor_override which is not equal to pooling region.")
+                warnings.warn(
+                    "[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.",
+                    stacklevel=1,
+                )
                 return
 
         avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
```
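
The remaining builder changes in this commit repeat the same conversion: replace a bare print with warnings.warn, prefix the message with the shared "[QNN Delegate Op Builder]:" tag, and return early so the node falls back. A minimal sketch of the shared convention (the helper is illustrative, not something the commit adds):

```python
import warnings


def warn_unsupported(reason: str) -> None:
    # Hypothetical helper capturing the unified message format used
    # across the op builders in this commit.
    warnings.warn(f"[QNN Delegate Op Builder]: {reason}", stacklevel=1)


warn_unsupported("QNN does not support output padding.")
# UserWarning: [QNN Delegate Op Builder]: QNN does not support output padding.
```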

backends/qualcomm/builders/op_cat.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -43,8 +44,9 @@ def define_node(
         )
 
         if len(list_of_tensors) != len(list_of_tensor_wrappers):
-            print(
-                "The number or input tensors is not equal to the number of input tensor wrappers."
+            warnings.warn(
+                "[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.",
+                stacklevel=1,
             )
             return
```

backends/qualcomm/builders/op_conv2d.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -189,12 +190,18 @@ def _define_conv1d(
 
         # args[6] = transposed
         if cast(bool, node.args[6]):
-            print("Currently, No support for transposed convolution")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Currently, No support for transposed convolution.",
+                stacklevel=1,
+            )
             return
 
         # args[7] = output padding
         if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
-            print("QNN does not support output padding")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: QNN does not support output padding.",
+                stacklevel=1,
+            )
             return
 
         stride_shape = [len(stride)]
```

backends/qualcomm/builders/op_expand.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -52,8 +53,9 @@ def define_node(
         output_dims = len(output_tensor.size())
 
         if input_dims < output_dims:
-            print(
-                f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.",
+                stacklevel=1,
             )
             return
```

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -44,7 +45,10 @@ def define_node(
             len(normalized_shapes) != 1
             and normalized_shapes[0] != input_tensor.shape[-1]
         ):
-            print("Only supports normalization with last input dimension")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
+                stacklevel=1,
+            )
             return
         axis = [len(input_tensor.shape) - 1]
         axis_shape = [len(axis)]
```

backends/qualcomm/builders/op_linear.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,8 +71,9 @@ def define_node(
 
         # TODO remove this when qnn sdk support
         if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
-            print(
-                f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.",
+                stacklevel=1,
             )
         bias_tensor = get_parameter(bias_node, self.edge_program)
         bias_tensor_wrapper = self.define_tensor(
```