Commit dcc3978

Qualcomm AI Engine Direct - Suite Operator Test Support (Part1) (#14618)
### Summary
- Support Add/Sub with alpha values
- Support Conv3d
- Support TransposeConv3d

### Test plan
UT added.
1 parent 6531b4a commit dcc3978
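For context, PyTorch's `add`/`sub` take an optional `alpha` multiplier applied to the second operand, which QNN has no direct equivalent for; the new pass rewrites such nodes. A minimal illustration of the semantics in plain PyTorch:

```python
import torch

# alpha scales the second operand before the add/sub is applied
x, y = torch.rand(3), torch.rand(3)
assert torch.allclose(torch.add(x, y, alpha=2), x + 2 * y)
assert torch.allclose(torch.sub(x, y, alpha=2), x - 2 * y)
```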

File tree: 11 files changed (+354, −56 lines)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
+from .decompose_binary_alpha import DecomposeBinaryAlpha
 from .decompose_cdist import DecomposeCDist
 from .decompose_col_im import DecomposeColIm
 from .decompose_einsum import DecomposeEinsum
@@ -54,6 +55,7 @@
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
+    DecomposeBinaryAlpha,
     DecomposeCDist,
     DecomposeColIm,
     DecomposeEinsum,
```

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -34,6 +34,7 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
         self.transpose_conv_set = {
             torch.ops.aten.conv_transpose1d.default,
             torch.ops.aten.conv_transpose2d.input,
+            torch.ops.aten.conv_transpose3d.input,
         }

     def dilate(self, tensor, dilation):
```
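The `dilate` helper in the trailing context is why the conv builder (later in this commit) can assert `dilation == 1` for transpose convs: CanonicalizeConv pre-dilates the kernel by inserting zeros between taps. An illustrative standalone sketch of zero-insertion dilation (my own, not the backend's exact code):

```python
import torch

def dilate(tensor: torch.Tensor, dilation) -> torch.Tensor:
    """Insert zeros between kernel taps along the last two dims."""
    h, w = tensor.shape[-2:]
    dh, dw = dilation
    out = tensor.new_zeros(*tensor.shape[:-2], (h - 1) * dh + 1, (w - 1) * dw + 1)
    out[..., ::dh, ::dw] = tensor  # original taps land on a strided zero grid
    return out

k = torch.arange(1.0, 5.0).reshape(2, 2)
print(dilate(k, [2, 2]))
# tensor([[1., 0., 2.],
#         [0., 0., 0.],
#         [3., 0., 4.]])
```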
backends/qualcomm/_passes/decompose_binary_alpha.py (new file)

Lines changed: 61 additions & 0 deletions
```diff
@@ -0,0 +1,61 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+decomp_set = {torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor}
+
+
+class DecomposeBinaryAlpha(ExportPass):
+    """
+    QNN does not support alpha parameter for add/sub.
+    Decompose to mul + add / mul + sub
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if (
+                node.target in decomp_set
+                and "alpha" in node.kwargs
+                and node.kwargs["alpha"] != 1
+            ):
+                alpha = node.kwargs["alpha"]
+                # Remove alpha from immutable dict
+                node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"}
+                input2_node = node.args[1]
+                # If input2 is constant, we can just multiply the value for optimization
+                if isinstance(input2_node, (int, float)):
+                    arg_list = list(node.args)
+                    arg_list[1] = input2_node * alpha
+                    node.args = tuple(arg_list)
+                    continue
+                with graph.inserting_before(node):
+                    mul_op = torch.ops.aten.mul.Scalar
+                    mul_node = graph.create_node(
+                        "call_function",
+                        mul_op,
+                        (
+                            input2_node,
+                            alpha,
+                        ),
+                    )
+                    mul_node.meta = copy_meta(node.meta)
+                    node.replace_input_with(input2_node, mul_node)
+                    node.args = (
+                        node.args[0],
+                        mul_node,
+                    )
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
```
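A minimal sketch of running the new pass in isolation, assuming an ExecuTorch install (the `AlphaSub` module is my own example; in-tree the pass is scheduled by `QnnPassManager`, see the next file):

```python
import torch
from executorch.backends.qualcomm._passes import DecomposeBinaryAlpha

class AlphaSub(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y, alpha=2)  # x - 2 * y

ep = torch.export.export(AlphaSub(), (torch.rand(4), torch.rand(4)))
gm = ep.module()  # fx GraphModule with an aten.sub.Tensor node carrying alpha=2
result = DecomposeBinaryAlpha()(gm)

# The alpha kwarg is gone: the graph now multiplies y by 2, then subtracts.
x, y = torch.rand(4), torch.rand(4)
assert torch.allclose(result.graph_module(x, y), x - 2 * y)
```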

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
+    DecomposeBinaryAlpha,
     DecomposeCDist,
     DecomposeColIm,
     DecomposeEinsum,
@@ -194,6 +195,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
         self.add_pass(RecomposeRmsNorm(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
+        self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
@@ -210,6 +212,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
     def transform_for_export_pipeline(
         self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
     ):
+        self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
```

backends/qualcomm/builders/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -365,7 +365,7 @@ Please help update following table if you are contributing new operators:
 + 🚫 = Deprecated, supported with other QNN Ops


-| Operators | HTP - 90/116 Enabled |
+| Operators | HTP - 92/116 Enabled |
 |-----------|---------|
 | Argmax | ✓ |
 | Argmin | ✓ |
@@ -375,7 +375,7 @@ Please help update following table if you are contributing new operators:
 | ChannelShuffle | ✗ |
 | Concat | ✓ |
 | Conv2d | ✓ |
-| Conv3d | ✗ |
+| Conv3d | ✓ |
 | Convert | ✓ |
 | CreateSparse | ✗ |
 | CumulativeSum | ✓ |
@@ -481,7 +481,7 @@ Please help update following table if you are contributing new operators:
 | TopK | ✓ |
 | TransPose | ✓ |
 | TransPoseConv2d | ✓ |
-| TransPoseConv3d | ✗ |
+| TransPoseConv3d | ✓ |
 | Unpack | ✓ |

 ## Issues
```

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,7 +24,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
@@ -129,7 +129,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
```

backends/qualcomm/builders/op_conv2d.py renamed to backends/qualcomm/builders/op_conv.py

Lines changed: 31 additions & 15 deletions
```diff
@@ -7,7 +7,6 @@
 from typing import cast, Dict, List

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_DATA
@@ -16,8 +15,10 @@
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
+    OpConv3d,
     OpDepthWiseConv2d,
     OpTransposeConv2d,
+    OpTransposeConv3d,
     QNN_OP_PACKAGE_NAME_QTI_AISW,
 )
 from .utils import get_parameter
@@ -66,7 +67,7 @@ def _add_conv_op_parameter(
             len(padding_shape),
             padding_shape,
             np.array(
-                [[padding[0], padding[0]], [padding[1], padding[1]]],
+                padding,
                 dtype=np.uint32,
             ),
             True,
@@ -108,8 +109,14 @@ def define_node(
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         assert (
-            input_tensor.dim() == 4
+            input_tensor.dim() != 3
         ), "All Conv1D should be converted to Conv2D in CanonicalizeConv,"
+        assert input_tensor.dim() in {
+            4,
+            5,
+        }, "Only Conv2d and Conv3d is supported in conv builder,"
+
+        is_conv2d = input_tensor.dim() == 4
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -120,9 +127,15 @@ def define_node(

         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
-        # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
+        # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
+        # yet QNN is HWIO or DHWIO
         is_transpose_conv = cast(bool, node.args[6])
-        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        if is_conv2d:
+            filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        else:
+            filter_axis_order = (
+                (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0)
+            )
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
@@ -132,7 +145,6 @@ def define_node(
             nodes_to_wrappers,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
             bias_tensor = get_parameter(bias_node, self.edge_program)
@@ -159,11 +171,10 @@ def define_node(
         padding = cast(List[int], node.args[4])
         dilation = cast(List[int], node.args[5])
         output_padding = cast(List[int], node.args[7])
-
         groups = cast(int, node.args[8])
-        # Qnn filter tensor is (H, W, Cin, Cout)
-        group_input_channels = filter_tensor.shape[2]
-        group_output_channels = int(filter_tensor.shape[3] / groups)
+        # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
+        group_input_channels = filter_tensor.shape[-2]
+        group_output_channels = int(filter_tensor.shape[-1] / groups)
         # 1) groups = input_channels (i.e. group_input_channels = 1)
         # 2) output_channels is a positive integer multiple of input channels
         # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
@@ -175,18 +186,23 @@ def define_node(
         )
         if len(padding) == 1:
             padding = padding + padding
+        padding = [[x, x] for x in padding]

         stride_shape = [len(stride)]
-        padding_shape = [2, 2]
+        padding_shape = [len(padding), len(padding[0])]
         dilation_shape = [len(dilation)]
         output_padding_shape = [len(output_padding)]

-        if is_depthwise_conv:
+        if is_transpose_conv:
+            assert all(
+                val == 1 for val in dilation
+            ), "CanonicalizeConv pass should perform dilate for transpose_conv."
+            op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
+        elif is_depthwise_conv:
+            assert is_conv2d, "DepthWise only supports Conv2d"
             op_class = OpDepthWiseConv2d
-        elif is_transpose_conv:
-            op_class = OpTransposeConv2d
         else:
-            op_class = OpConv2d
+            op_class = OpConv2d if is_conv2d else OpConv3d

         conv_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,
```
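As a sanity check on the new 5-D filter reordering, using the axis orders from the hunk above (PyTorch OIDHW for conv3d, IODHW for conv_transpose3d; QNN expects DHWIO), the channel dims land last either way, which is why the builder can now read `shape[-2]`/`shape[-1]` for both ranks:

```python
import torch

w = torch.rand(8, 4, 3, 5, 7)           # conv3d weight: (O, I, D, H, W)
qnn_w = w.permute(2, 3, 4, 1, 0).contiguous()
assert qnn_w.shape == (3, 5, 7, 4, 8)   # (D, H, W, Cin, Cout)

wt = torch.rand(4, 8, 3, 5, 7)          # conv_transpose3d weight: (I, O, D, H, W)
qnn_wt = wt.permute(2, 3, 4, 0, 1).contiguous()
assert qnn_wt.shape == (3, 5, 7, 4, 8)  # same (D, H, W, Cin, Cout) layout
```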

backends/qualcomm/builders/qnn_constants.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -59,6 +59,15 @@ class OpConv2d:
     param_dilation: str = "dilation"


+@dataclass(init=False, frozen=True)
+class OpConv3d:
+    op_name: str = "Conv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_dilation: str = "dilation"
+
+
 @dataclass(init=False, frozen=True)
 class OpConvert:
     op_name: str = "Convert"
@@ -573,6 +582,15 @@ class OpTransposeConv2d:
     param_output_padding: str = "output_padding"


+@dataclass(init=False, frozen=True)
+class OpTransposeConv3d:
+    op_name: str = "TransposeConv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_output_padding: str = "output_padding"
+
+
 @dataclass(init=False, frozen=True)
 class OpUnpack:
     op_name: str = "UnPack"
```
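These op classes are never instantiated: `init=False, frozen=True` turns each into a read-only namespace of QNN op/parameter name strings that the builder reads as class attributes. A self-contained sketch of the pattern (mirroring the diff, not imported from the repo):

```python
from dataclasses import dataclass

# Same pattern as the diff: a frozen, non-instantiable namespace of QNN names.
@dataclass(init=False, frozen=True)
class OpConv3d:
    op_name: str = "Conv3d"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_group: str = "group"
    param_dilation: str = "dilation"

assert OpConv3d.op_name == "Conv3d"  # accessed on the class, no instance needed
```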

backends/qualcomm/quantizer/annotators.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1094,11 +1094,13 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:

 @register_annotator(
     [
+        torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.conv2d.padding,
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv3d.default,
         torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d.input,
         torch.ops.aten.convolution.default,
     ]
 )
```
