Commit bd90322

Qualcomm AI Engine Direct - issue fix (#13849)
### Summary
- 13611 > add dilation support for transpose conv
- 13604 > add dimension check for reduce operator
- 13607 > fix outdated document & example code

### Test plan
```
MATRIX = {
    TestQNNQuantizedOperator.test_qnn_backend_conv_transpose1d,
    TestQNNQuantizedOperator.test_qnn_backend_conv_transpose2d,
    test_qnn_backend_amax,
    test_qnn_backend_amin,
}
python backends/qualcomm/tests/test_qnn_delegate.py ${MATRIX} -b build-android -s $SN -m SM8750
```
1 parent d686472 commit bd90322

File tree

13 files changed (+150, -59 lines)


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,8 +8,8 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
-from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -47,8 +47,8 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    CanonicalizeConv,
     ConvertBmmToMatmul,
-    ConvertConv1dToConv2d,
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
```

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py renamed to backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 71 additions & 7 deletions
```diff
@@ -4,32 +4,96 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast, Tuple
+
 import torch
+
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._guards import detect_fake_mode
 
 from .utils import append_qdq, copy_meta
 
 
-class ConvertConv1dToConv2d(ExportPass):
+class CanonicalizeConv(ExportPass):
     """
-    Conv1d is not supported by QNN.
-    Change it to input -> unsqueeze -> conv2d -> squeeze -> output
+    1. QNN does not support dilation on TransposeConvND.
+       Dilate the kernel manually for math-equivalent operation.
+    2. Conv1d is not supported by QNN.
+       Change it to input -> unsqueeze -> conv2d -> squeeze -> output
     """
 
     def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(ConvertConv1dToConv2d, self).__init__()
+        super(CanonicalizeConv, self).__init__()
         self.edge_program = edge_program
-        self.conv_op_map = {
+        self.conv1d_op_map = {
             torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
             torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
         }
+        self.transpose_conv_set = {
+            torch.ops.aten.conv_transpose1d.default,
+            torch.ops.aten.conv_transpose2d.input,
+        }
+
+    def dilate(self, tensor, dilation):
+        # e.g.
+        # for 3x3 kernel with dilation == (2, 3)
+        #             1, 0, 0, 2, 0, 0, 3
+        # 1, 2, 3     0, 0, 0, 0, 0, 0, 0
+        # 4, 5, 6 --> 4, 0, 0, 5, 0, 0, 6
+        # 7, 8, 9     0, 0, 0, 0, 0, 0, 0
+        #             7, 0, 0, 8, 0, 0, 9
+        i, o, *k = tensor.shape
+        new_k = [dim + (dim - 1) * (s - 1) for s, dim in zip(dilation, k)]
+        new_tensor = torch.zeros((i, o, *new_k), dtype=tensor.dtype)
+        indexing = (...,) + tuple([slice(None, None, d) for d in dilation])
+        new_tensor[indexing] = tensor
+        return new_tensor
 
     def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
+        # condition 1
+        for node in graph.nodes:
+            # arg order (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html)
+            # > input, weight, bias, stride, padding, output_padding, groups, dilation
+            if node.target in self.transpose_conv_set and len(node.args) > 7:
+                dilation = cast(Tuple[int], node.args[7])
+                # dilate kernel in advance
+                filter_arg = node.args[1]
+                filter_node = (
+                    # fp graph
+                    filter_arg
+                    if filter_arg.op == "placeholder"
+                    # qdq graph
+                    else node.args[1].args[0]
+                )
+                filter_tensor = self.dilate(
+                    get_parameter(filter_node, self.edge_program),
+                    dilation,
+                )
+                # update tensor meta for kernel node
+                fake_mode = detect_fake_mode(filter_node.meta["val"])
+                converter = fake_mode.fake_tensor_converter
+                filter_node.meta["val"] = converter.from_real_tensor(
+                    fake_mode, filter_tensor
+                )
+                # update kernel
+                set_parameter(
+                    (
+                        torch.nn.Parameter(filter_tensor)
+                        if filter_tensor.dtype == torch.float
+                        else filter_tensor
+                    ),
+                    filter_node,
+                    self.edge_program,
+                )
+                # pop dilation for graph in cpu
+                node.args = node.args[0:-1]
+
+        # condition 2
         for node in graph.nodes:
-            if node.target in self.conv_op_map:
+            if node.target in self.conv1d_op_map:
                 input_node = node.args[0]
                 with graph_module.graph.inserting_after(input_node):
                     unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
@@ -108,7 +172,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 )
                 conv2d_node = graph.create_node(
                     "call_function",
-                    self.conv_op_map[node.target],
+                    self.conv1d_op_map[node.target],
                     conv_args,
                 )
                 conv2d_node.meta = copy_meta(
```
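The kernel-dilation rewrite above can be sanity-checked outside the pass. Below is a minimal sketch (the helper `dilate` mirrors the pass's method; the input shapes are assumptions of mine) verifying that a dilated `conv_transpose2d` equals an undilated one applied with the pre-dilated kernel:

```python
import torch
import torch.nn.functional as F


def dilate(kernel, dilation):
    # Insert (d - 1) zeros between kernel taps along each spatial dim,
    # mirroring CanonicalizeConv.dilate.
    i, o, *k = kernel.shape
    new_k = [dim + (dim - 1) * (d - 1) for d, dim in zip(dilation, k)]
    out = torch.zeros((i, o, *new_k), dtype=kernel.dtype)
    out[(...,) + tuple(slice(None, None, d) for d in dilation)] = kernel
    return out


x = torch.randn(1, 1, 8, 8)
w = torch.randn(1, 3, 3, 3)  # (in_channels, out_channels, kH, kW)

ref = F.conv_transpose2d(x, w, stride=2, padding=1, dilation=2)
got = F.conv_transpose2d(x, dilate(w, (2, 2)), stride=2, padding=1)
assert torch.allclose(ref, got, atol=1e-5)
```

Both calls produce identical shapes and values, which is why the pass can simply drop the trailing `dilation` argument once the weight has been rewritten.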

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -13,8 +13,8 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    CanonicalizeConv,
     ConvertBmmToMatmul,
-    ConvertConv1dToConv2d,
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -82,6 +82,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (CanonicalizeConv, True),
         (ConvertBmmToMatmul, False),
         (DecomposeAny, True),
         (DecomposeColIm, True),
@@ -215,7 +216,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeWrapWithAutocast())
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
-        self.add_pass(ConvertConv1dToConv2d(exported_program))
+        self.add_pass(CanonicalizeConv(exported_program))
         if convert_linear_to_conv2d:
             self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
```

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -41,12 +41,7 @@ def __init__(self, quantization_capture=False):
         )
 
     def _dim_order_op_condition(self, node):
-        dim_order = node.kwargs.get("dim_order")
-        # skip if there contains layout hint
-        # e.g. (0, 2, 3, 1) != (0, 1, 2, 3)
-        if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
-            return False
-        return dim_order != list(range(len(dim_order)))
+        return node.meta["val"].dtype == node.args[0].meta["val"].dtype
 
     def _to_copy_op_condition(self, node):
         return "memory_format" in node.kwargs
```

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -64,8 +64,8 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        CanonicalizeConv,
         ConvertBmmToMatmul,
-        ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
         DecomposeLinalgVectorNorm,
@@ -99,7 +99,7 @@ def get_passes_dependency_for_capture_program():
         I64toI32: [RemoveRedundancy],
         LayoutTransform: [
             AnnotateQuantAttrs,
-            ConvertConv1dToConv2d,
+            CanonicalizeConv,
             ExpandBroadcastTensorShape,
             FixedLinearKeepDim,
         ],
```

backends/qualcomm/builders/op_amax.py

Lines changed: 11 additions & 6 deletions
```diff
@@ -40,14 +40,19 @@ def define_node(
         )
 
         # mean dims and keep dims
-        mean_dims = cast(List[int], node.args[1])
-        mean_dims = [
-            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
-        ]
-        if QCOM_AXIS_ORDER in node.meta:
+        if len(node.args) > 1:
+            mean_dims = cast(List[int], node.args[1])
             mean_dims = [
-                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
             ]
+            if QCOM_AXIS_ORDER in node.meta:
+                mean_dims = [
+                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                ]
+        else:
+            # reduce all dimensions
+            mean_dims = list(range(input_node.meta["val"].dim()))
+
         mean_dims_shape = [len(mean_dims)]
 
         output_tensor = self.get_tensor(node, node)
```

backends/qualcomm/builders/op_amin.py

Lines changed: 11 additions & 6 deletions
```diff
@@ -40,14 +40,19 @@ def define_node(
         )
 
         # mean dims and keep dims
-        mean_dims = cast(List[int], node.args[1])
-        mean_dims = [
-            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
-        ]
-        if QCOM_AXIS_ORDER in node.meta:
+        if len(node.args) > 1:
+            mean_dims = cast(List[int], node.args[1])
             mean_dims = [
-                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
             ]
+            if QCOM_AXIS_ORDER in node.meta:
+                mean_dims = [
+                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                ]
+        else:
+            # reduce all dimensions
+            mean_dims = list(range(input_node.meta["val"].dim()))
+
         mean_dims_shape = [len(mean_dims)]
 
         output_tensor = self.get_tensor(node, node)
```
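Both builders now handle the `dim`-less form of the reduce ops: `aten.amax`/`aten.amin` default `dim` to `[]`, which reduces every dimension, so a node carrying only the input argument falls back to all axes. A quick check of that semantics in eager PyTorch:

```python
import torch

x = torch.randn(2, 3, 4)
# No dims given -> reduce over all dimensions, matching the builders'
# fallback of list(range(x.dim())).
assert torch.equal(torch.amax(x), torch.amax(x, dim=(0, 1, 2)))
assert torch.equal(torch.amin(x), torch.amin(x, dim=(0, 1, 2)))
```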

backends/qualcomm/builders/op_conv2d.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -109,7 +109,7 @@ def define_node(
         input_tensor = self.get_tensor(input_node, node)
         assert (
             input_tensor.dim() == 4
-        ), "All Conv should be converted to Conv2D in ConvertConv1dToConv2d"
+        ), "All Conv1D should be converted to Conv2D in CanonicalizeConv,"
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
```

backends/qualcomm/quantizer/annotators.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -214,7 +214,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator([torch.ops.aten.amax.default])
 def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_binary(node, quantization_config)
+    annotate_single_in_single_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.argmax.default])
@@ -224,7 +224,7 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None
 
 @register_annotator([torch.ops.aten.amin.default])
 def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_binary(node, quantization_config)
+    annotate_single_in_single_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.argmin.default])
```
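The annotator swap matches the ops' arity: `amax`/`amin` take one tensor input and produce one tensor output (the `dim` argument is a plain attribute, not a tensor), so the single-in/single-out annotation fits where a binary one did not. For example:

```python
import torch

x = torch.randn(2, 3)
print(torch.amax(x, dim=0).shape)  # torch.Size([3]): one tensor in, one out
print(torch.amin(x, dim=1).shape)  # torch.Size([2])
```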

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -589,25 +589,32 @@ def forward(self, x):
 
 
 class ConvTranspose1dSingle(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, bias=True, dilation=1):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose1d(
-            in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            dilation=dilation,
+            bias=bias,
         )
 
     def forward(self, x):
         return self.conv_transpose(x)
 
 
 class ConvTranspose2dSingle(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, bias=True, dilation=1):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose2d(
             in_channels=1,
             out_channels=3,
             kernel_size=3,
             stride=2,
             padding=1,
+            dilation=dilation,
             bias=bias,
         )
```
