Skip to content

Commit 7e99148

Browse files
committed
Code Review 2
1 parent 3bcc6de commit 7e99148

File tree

4 files changed

+12
-22
lines changed

4 files changed

+12
-22
lines changed

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
3434
self.transpose_conv_set = {
3535
torch.ops.aten.conv_transpose1d.default,
3636
torch.ops.aten.conv_transpose2d.input,
37+
torch.ops.aten.conv_transpose3d.input,
3738
}
3839

3940
def dilate(self, tensor, dilation):

backends/qualcomm/builders/op_conv.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import warnings
87
from typing import cast, Dict, List
98

109
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -128,7 +127,8 @@ def define_node(
128127

129128
filter_node = self.get_node(node.args[1])
130129
filter_tensor = get_parameter(filter_node, self.edge_program)
131-
# weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
130+
# weight of pytorch OIHW(conv2d) / IODHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
131+
# yet QNN is HWIO or DHWIO
132132
is_transpose_conv = cast(bool, node.args[6])
133133
if is_conv2d:
134134
filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
@@ -193,23 +193,10 @@ def define_node(
193193
dilation_shape = [len(dilation)]
194194
output_padding_shape = [len(output_padding)]
195195

196-
# Some transpose_conv3d restrictions by QNN
197-
if not is_conv2d and is_transpose_conv:
198-
if dilation[0] != 1:
199-
200-
warnings.warn(
201-
"[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d only supports dilation of 1 in the depth dimension.",
202-
stacklevel=1,
203-
)
204-
205-
# If using dilation, then stride height and width must = 1
206-
if any(x != 1 for x in dilation) and (stride[1] != 1 or stride[2] != 1):
207-
warnings.warn(
208-
"[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d, only supports dilation if stride in height and width = 1.",
209-
stacklevel=1,
210-
)
211-
212196
if is_transpose_conv:
197+
assert all(
198+
val == 1 for val in dilation
199+
), "CanonicalizeConv pass should perform dilate for transpose_conv."
213200
op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
214201
elif is_depthwise_conv:
215202
assert is_conv2d, "DepthWise only supports Conv2d"

backends/qualcomm/tests/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,8 @@ def __init__(self, bias=True, dilation=1):
765765
self.conv_transpose = torch.nn.ConvTranspose3d(
766766
in_channels=1,
767767
out_channels=3,
768-
kernel_size=1,
769-
stride=(1, 1, 1),
768+
kernel_size=3,
769+
stride=2,
770770
padding=1,
771771
dilation=dilation,
772772
bias=bias,

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ def test_qnn_backend_conv_transpose3d(self):
317317
modules = [
318318
ConvTranspose3dSingle(), # noqa: F405
319319
ConvTranspose3dSingle(bias=False), # noqa: F405
320-
ConvTranspose3dSingle(dilation=(1, 2, 3)), # noqa: F405
320+
ConvTranspose3dSingle(dilation=2), # noqa: F405
321+
ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405
321322
]
322323
sample_input = (torch.randn([1, 1, 3, 3, 3]),)
323324
for i, module in enumerate(modules):
@@ -1897,7 +1898,8 @@ def test_qnn_backend_conv_transpose3d(self):
18971898
modules = [
18981899
ConvTranspose3dSingle(), # noqa: F405
18991900
ConvTranspose3dSingle(bias=False), # noqa: F405
1900-
ConvTranspose3dSingle(dilation=(1, 2, 3)), # noqa: F405
1901+
ConvTranspose3dSingle(dilation=2), # noqa: F405
1902+
ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405
19011903
]
19021904
sample_input = (torch.randn([1, 1, 3, 3, 3]),)
19031905
for i, module in enumerate(modules):

0 commit comments

Comments
 (0)