|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import warnings |
8 | 7 | from typing import cast, Dict, List |
9 | 8 |
|
10 | 9 | import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper |
@@ -128,7 +127,8 @@ def define_node( |
128 | 127 |
|
129 | 128 | filter_node = self.get_node(node.args[1]) |
130 | 129 | filter_tensor = get_parameter(filter_node, self.edge_program) |
131 | | - # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO |
| 130 | + # weight of pytorch OIHW(conv2d) / IODHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d), |
| 131 | + # yet QNN is HWIO or DHWIO |
132 | 132 | is_transpose_conv = cast(bool, node.args[6]) |
133 | 133 | if is_conv2d: |
134 | 134 | filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) |
@@ -193,23 +193,10 @@ def define_node( |
193 | 193 | dilation_shape = [len(dilation)] |
194 | 194 | output_padding_shape = [len(output_padding)] |
195 | 195 |
|
196 | | - # Some transpose_conv3d restrictions by QNN |
197 | | - if not is_conv2d and is_transpose_conv: |
198 | | - if dilation[0] != 1: |
199 | | - |
200 | | - warnings.warn( |
201 | | - "[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d only supports dilation of 1 in the depth dimension.", |
202 | | - stacklevel=1, |
203 | | - ) |
204 | | - |
205 | | - # If using dilation, then stride height and width must = 1 |
206 | | - if any(x != 1 for x in dilation) and (stride[1] != 1 or stride[2] != 1): |
207 | | - warnings.warn( |
208 | | - "[QNN Delegate Op Builder]: As of QNN 2.37, TransposeConv3d, only supports dilation if stride in height and width = 1.", |
209 | | - stacklevel=1, |
210 | | - ) |
211 | | - |
212 | 196 | if is_transpose_conv: |
| 197 | + assert all( |
| 198 | + val == 1 for val in dilation |
| 199 | + ), "CanonicalizeConv pass should perform dilate for transpose_conv." |
213 | 200 | op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d |
214 | 201 | elif is_depthwise_conv: |
215 | 202 | assert is_conv2d, "DepthWise only supports Conv2d" |
|
0 commit comments