Commit 5bb48cc

Support Conv3d and TransposeConv3d
1 parent d9f3f62

7 files changed (+179 −53 lines)

backends/qualcomm/builders/README.md

Lines changed: 3 additions & 3 deletions
@@ -365,7 +365,7 @@ Please help update following table if you are contributing new operators:
 + 🚫 = Deprecated, supported with other QNN Ops
 
 
-| Operators | HTP - 90/116 Enabled |
+| Operators | HTP - 92/116 Enabled |
 |-----------|---------|
 | Argmax | ✓ |
 | Argmin | ✓ |
@@ -375,7 +375,7 @@ Please help update following table if you are contributing new operators:
 | ChannelShuffle | ✗ |
 | Concat | ✓ |
 | Conv2d | ✓ |
-| Conv3d | ✗ |
+| Conv3d | ✓ |
 | Convert | ✓ |
 | CreateSparse | ✗ |
 | CumulativeSum | ✓ |
@@ -481,7 +481,7 @@ Please help update following table if you are contributing new operators:
 | TopK | ✓ |
 | TransPose | ✓ |
 | TransPoseConv2d | ✓ |
-| TransPoseConv3d | ✗ |
+| TransPoseConv3d | ✓ |
 | Unpack | ✓ |
 
 ## Issues

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
@@ -129,7 +129,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
backends/qualcomm/builders/op_conv.py

Lines changed: 41 additions & 12 deletions
@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_DATA
@@ -16,8 +16,10 @@
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
+    OpConv3d,
     OpDepthWiseConv2d,
     OpTransposeConv2d,
+    OpTransposeConv3d,
     QNN_OP_PACKAGE_NAME_QTI_AISW,
 )
 from .utils import get_parameter
@@ -66,7 +68,7 @@ def _add_conv_op_parameter(
             len(padding_shape),
             padding_shape,
             np.array(
-                [[padding[0], padding[0]], [padding[1], padding[1]]],
+                padding,
                 dtype=np.uint32,
             ),
             True,
@@ -108,8 +110,14 @@ def define_node(
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         assert (
-            input_tensor.dim() == 4
+            input_tensor.dim() != 3
         ), "All Conv1D should be converted to Conv2D in CanonicalizeConv"
+        assert input_tensor.dim() in {
+            4,
+            5,
+        }, "Only Conv2d and Conv3d are supported in the conv builder"
+
+        is_conv2d = input_tensor.dim() == 4
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -122,7 +130,12 @@ def define_node(
         filter_tensor = get_parameter(filter_node, self.edge_program)
         # PyTorch weight is OIHW (conv2d) | IOHW (conv_transpose2d), yet QNN expects HWIO (DHWIO for 3D)
         is_transpose_conv = cast(bool, node.args[6])
-        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        if is_conv2d:
+            filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        else:
+            filter_axis_order = (
+                (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0)
+            )
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
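
To make the new 5-axis orders concrete: PyTorch stores Conv3d weights as (O, I, D, H, W) and ConvTranspose3d weights as (I, O, D, H, W), while QNN wants the kernel dims first and channels last, (D, H, W, I, O). A minimal sketch (not part of the diff; shapes are illustrative) verifying both permutations:

```python
import torch

# Conv3d weight (O=8, I=4, D=3, H=3, W=3) -> QNN (D, H, W, I, O)
conv_w = torch.randn(8, 4, 3, 3, 3)
assert conv_w.permute(2, 3, 4, 1, 0).shape == (3, 3, 3, 4, 8)

# ConvTranspose3d weight (I=4, O=8, D=3, H=3, W=3) -> QNN (D, H, W, I, O)
deconv_w = torch.randn(4, 8, 3, 3, 3)
assert deconv_w.permute(2, 3, 4, 0, 1).shape == (3, 3, 3, 4, 8)
```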
@@ -132,7 +145,6 @@ def define_node(
             nodes_to_wrappers,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
             bias_tensor = get_parameter(bias_node, self.edge_program)
@@ -159,11 +171,10 @@ def define_node(
         padding = cast(List[int], node.args[4])
         dilation = cast(List[int], node.args[5])
         output_padding = cast(List[int], node.args[7])
-
         groups = cast(int, node.args[8])
-        # Qnn filter tensor is (H, W, Cin, Cout)
-        group_input_channels = filter_tensor.shape[2]
-        group_output_channels = int(filter_tensor.shape[3] / groups)
+        # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
+        group_input_channels = filter_tensor.shape[-2]
+        group_output_channels = int(filter_tensor.shape[-1] / groups)
         # 1) groups = input_channels (i.e. group_input_channels = 1)
         # 2) output_channels is a positive integer multiple of input channels
         # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
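
Indexing from the end lets the same channel arithmetic cover both 4- and 5-axis filters, since Cin and Cout occupy the last two positions either way. A small sketch (illustrative, not from the commit), assuming a grouped Conv3d with in_channels=4, out_channels=8, groups=2:

```python
import torch

groups = 2
# PyTorch grouped Conv3d weight: (O, I/groups, D, H, W) = (8, 2, 3, 3, 3)
w = torch.randn(8, 2, 3, 3, 3).permute(2, 3, 4, 1, 0)  # -> (3, 3, 3, 2, 8)
group_input_channels = w.shape[-2]                      # 2
group_output_channels = int(w.shape[-1] / groups)       # 8 / 2 = 4
assert (group_input_channels, group_output_channels) == (2, 4)
```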
@@ -175,18 +186,36 @@ def define_node(
         )
         if len(padding) == 1:
             padding = padding + padding
+        padding = [[x, x] for x in padding]
 
         stride_shape = [len(stride)]
-        padding_shape = [2, 2]
+        padding_shape = [len(padding), len(padding[0])]
         dilation_shape = [len(dilation)]
         output_padding_shape = [len(output_padding)]
 
+        # Some transpose_conv3d restrictions imposed by QNN
+        if not is_conv2d and is_transpose_conv:
+            if dilation[0] != 1:
+                warnings.warn(
+                    "[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d only supports dilation of 1 in the depth dimension.",
+                    stacklevel=1,
+                )
+
+            # If dilation is used, stride in height and width must be 1
+            if any(x != 1 for x in dilation) and (stride[1] != 1 or stride[2] != 1):
+                warnings.warn(
+                    "[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d only supports dilation if stride in height and width is 1.",
+                    stacklevel=1,
+                )
+
         if is_depthwise_conv:
+            assert is_conv2d, "DepthWise only supports Conv2d"
             op_class = OpDepthWiseConv2d
         elif is_transpose_conv:
-            op_class = OpTransposeConv2d
+            op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
         else:
-            op_class = OpConv2d
+            op_class = OpConv2d if is_conv2d else OpConv3d
 
         conv_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,
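
The padding rework is what allows `_add_conv_op_parameter` to pass `padding` directly to `np.array`: every per-dimension pad becomes a `[before, after]` pair, and `padding_shape` is derived from the list instead of being hard-coded to `[2, 2]`. A minimal sketch of the transformation (not from the commit):

```python
# 2D case: a single symmetric pad expands to two dims, then to pairs.
padding = [1]
if len(padding) == 1:
    padding = padding + padding
padding = [[x, x] for x in padding]
assert padding == [[1, 1], [1, 1]]
assert [len(padding), len(padding[0])] == [2, 2]  # matches the old hard-coded shape

# 3D case: (D, H, W) pads yield a 3x2 pad_amount array.
padding_3d = [[x, x] for x in [1, 2, 2]]
assert [len(padding_3d), len(padding_3d[0])] == [3, 2]
```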

backends/qualcomm/builders/qnn_constants.py

Lines changed: 18 additions & 0 deletions
@@ -59,6 +59,15 @@ class OpConv2d:
     param_dilation: str = "dilation"
 
 
+@dataclass(init=False, frozen=True)
+class OpConv3d:
+    op_name: str = "Conv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_dilation: str = "dilation"
+
+
 @dataclass(init=False, frozen=True)
 class OpConvert:
     op_name: str = "Convert"
@@ -573,6 +582,15 @@ class OpTransposeConv2d:
     param_output_padding: str = "output_padding"
 
 
+@dataclass(init=False, frozen=True)
+class OpTransposeConv3d:
+    op_name: str = "TransposeConv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_output_padding: str = "output_padding"
+
+
 @dataclass(init=False, frozen=True)
 class OpUnpack:
     op_name: str = "UnPack"
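
These `init=False, frozen=True` dataclasses are never instantiated; they serve as read-only namespaces for the QNN op and parameter names the builder emits. A trivial sketch of the access pattern (the surrounding wrapper calls are omitted):

```python
from dataclasses import dataclass

@dataclass(init=False, frozen=True)
class OpTransposeConv3d:
    op_name: str = "TransposeConv3d"
    param_pad_amount: str = "pad_amount"

# The builder reads class attributes directly; no instance is created.
assert OpTransposeConv3d.op_name == "TransposeConv3d"
assert OpTransposeConv3d.param_pad_amount == "pad_amount"
```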

backends/qualcomm/quantizer/annotators.py

Lines changed: 4 additions & 2 deletions
@@ -1094,11 +1094,13 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator(
     [
+        torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.conv2d.padding,
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv3d.default,
         torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d.input,
         torch.ops.aten.convolution.default,
     ]
 )

backends/qualcomm/tests/models.py

Lines changed: 73 additions & 34 deletions
@@ -597,6 +597,28 @@ def forward(self, x):
         return self.second(self.first(x))
 
 
+class Conv3dSequential(torch.nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.first = torch.nn.Conv3d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+        self.second = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=2,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.second(self.first(x))
+
+
 class Conv2dSingle(torch.nn.Module):
     def __init__(
         self,
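
A quick smoke test for the new module (illustrative, not in the commit; assumes `Conv3dSequential` from above is in scope): with kernel 3 and padding 1 the spatial dims are preserved, while channels go 1 → 3 → 2.

```python
import torch

model = Conv3dSequential()
out = model(torch.randn(1, 1, 8, 8, 8))  # (N, C, D, H, W)
assert out.shape == (1, 2, 8, 8, 8)
```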
@@ -619,40 +641,6 @@ def forward(self, x):
         return self.conv(x)
 
 
-class ConvTranspose1dSingle(torch.nn.Module):
-    def __init__(self, bias=True, dilation=1):
-        super().__init__()
-        self.conv_transpose = torch.nn.ConvTranspose1d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            dilation=dilation,
-            bias=bias,
-        )
-
-    def forward(self, x):
-        return self.conv_transpose(x)
-
-
-class ConvTranspose2dSingle(torch.nn.Module):
-    def __init__(self, bias=True, dilation=1):
-        super().__init__()
-        self.conv_transpose = torch.nn.ConvTranspose2d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            dilation=dilation,
-            bias=bias,
-        )
-
-    def forward(self, x):
-        return self.conv_transpose(x)
-
-
 class Conv2dDownUpSample(torch.nn.Module):
     def __init__(self, bias=True):
         super().__init__()
@@ -737,6 +725,57 @@ def forward(self, x):
         return topk_values
 
 
+class ConvTranspose1dSingle(torch.nn.Module):
+    def __init__(self, bias=True, dilation=1):
+        super().__init__()
+        self.conv_transpose = torch.nn.ConvTranspose1d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            dilation=dilation,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv_transpose(x)
+
+
+class ConvTranspose2dSingle(torch.nn.Module):
+    def __init__(self, bias=True, dilation=1):
+        super().__init__()
+        self.conv_transpose = torch.nn.ConvTranspose2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            dilation=dilation,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv_transpose(x)
+
+
+class ConvTranspose3dSingle(torch.nn.Module):
+    def __init__(self, bias=True, dilation=1):
+        super().__init__()
+        self.conv_transpose = torch.nn.ConvTranspose3d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=1,
+            stride=(1, 1, 1),
+            padding=1,
+            dilation=dilation,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv_transpose(x)
+
+
 class Cos(torch.nn.Module):
     def __init__(self):
         super().__init__()
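
Note the shrinking output here: for ConvTranspose3d, L_out = (L_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1, so kernel_size=1, stride=1, padding=1 gives L_out = L_in − 2. A hedged usage sketch (not in the commit; assumes the class above is in scope):

```python
import torch

model = ConvTranspose3dSingle()
out = model(torch.randn(1, 1, 8, 8, 8))
# per spatial dim: (8 - 1)*1 - 2*1 + 1*(1 - 1) + 0 + 1 = 6
assert out.shape == (1, 3, 6, 6, 6)
```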
