7
7
from typing import cast , Dict , List
8
8
9
9
import executorch .backends .qualcomm .python .PyQnnWrapperAdaptor as PyQnnWrapper
10
-
11
10
import numpy as np
12
11
import torch
13
12
from executorch .backends .qualcomm .utils .constants import QCOM_DATA
16
15
from .node_visitor_manager import register_node_visitor
17
16
from .qnn_constants import (
18
17
OpConv2d ,
18
+ OpConv3d ,
19
19
OpDepthWiseConv2d ,
20
20
OpTransposeConv2d ,
21
+ OpTransposeConv3d ,
21
22
QNN_OP_PACKAGE_NAME_QTI_AISW ,
22
23
)
23
24
from .utils import get_parameter
@@ -66,7 +67,7 @@ def _add_conv_op_parameter(
66
67
len (padding_shape ),
67
68
padding_shape ,
68
69
np .array (
69
- [[ padding [ 0 ], padding [ 0 ]], [ padding [ 1 ], padding [ 1 ]]] ,
70
+ padding ,
70
71
dtype = np .uint32 ,
71
72
),
72
73
True ,
@@ -108,8 +109,14 @@ def define_node(
108
109
input_node = self .get_node (node .args [0 ])
109
110
input_tensor = self .get_tensor (input_node , node )
110
111
assert (
111
- input_tensor .dim () == 4
112
+ input_tensor .dim () != 3
112
113
), "All Conv1D should be converted to Conv2D in CanonicalizeConv,"
114
+ assert input_tensor .dim () in {
115
+ 4 ,
116
+ 5 ,
117
+ }, "Only Conv2d and Conv3d is supported in conv builder,"
118
+
119
+ is_conv2d = input_tensor .dim () == 4
113
120
input_tensor_wrapper = self .define_tensor (
114
121
input_node ,
115
122
node ,
@@ -120,9 +127,15 @@ def define_node(
120
127
121
128
filter_node = self .get_node (node .args [1 ])
122
129
filter_tensor = get_parameter (filter_node , self .edge_program )
123
- # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
130
+ # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
131
+ # yet QNN is HWIO or DHWIO
124
132
is_transpose_conv = cast (bool , node .args [6 ])
125
- filter_axis_order = (2 , 3 , 0 , 1 ) if is_transpose_conv else (2 , 3 , 1 , 0 )
133
+ if is_conv2d :
134
+ filter_axis_order = (2 , 3 , 0 , 1 ) if is_transpose_conv else (2 , 3 , 1 , 0 )
135
+ else :
136
+ filter_axis_order = (
137
+ (2 , 3 , 4 , 0 , 1 ) if is_transpose_conv else (2 , 3 , 4 , 1 , 0 )
138
+ )
126
139
filter_tensor = filter_tensor .permute (dims = filter_axis_order ).contiguous ()
127
140
filter_tensor_wrapper = self .define_tensor (
128
141
filter_node ,
@@ -132,7 +145,6 @@ def define_node(
132
145
nodes_to_wrappers ,
133
146
)
134
147
conv_input_tensors = [input_tensor_wrapper , filter_tensor_wrapper ]
135
-
136
148
if node .args [2 ] is not None :
137
149
bias_node = self .get_node (node .args [2 ])
138
150
bias_tensor = get_parameter (bias_node , self .edge_program )
@@ -159,11 +171,10 @@ def define_node(
159
171
padding = cast (List [int ], node .args [4 ])
160
172
dilation = cast (List [int ], node .args [5 ])
161
173
output_padding = cast (List [int ], node .args [7 ])
162
-
163
174
groups = cast (int , node .args [8 ])
164
- # Qnn filter tensor is (H, W, Cin, Cout)
165
- group_input_channels = filter_tensor .shape [2 ]
166
- group_output_channels = int (filter_tensor .shape [3 ] / groups )
175
+ # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
176
+ group_input_channels = filter_tensor .shape [- 2 ]
177
+ group_output_channels = int (filter_tensor .shape [- 1 ] / groups )
167
178
# 1) groups = input_channels (i.e. group_input_channels = 1)
168
179
# 2) output_channels is a positive integer multiple of input channels
169
180
# TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
@@ -175,18 +186,23 @@ def define_node(
175
186
)
176
187
if len (padding ) == 1 :
177
188
padding = padding + padding
189
+ padding = [[x , x ] for x in padding ]
178
190
179
191
stride_shape = [len (stride )]
180
- padding_shape = [2 , 2 ]
192
+ padding_shape = [len ( padding ), len ( padding [ 0 ]) ]
181
193
dilation_shape = [len (dilation )]
182
194
output_padding_shape = [len (output_padding )]
183
195
184
- if is_depthwise_conv :
196
+ if is_transpose_conv :
197
+ assert all (
198
+ val == 1 for val in dilation
199
+ ), "CanonicalizeConv pass should perform dilate for transpose_conv."
200
+ op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
201
+ elif is_depthwise_conv :
202
+ assert is_conv2d , "DepthWise only supports Conv2d"
185
203
op_class = OpDepthWiseConv2d
186
- elif is_transpose_conv :
187
- op_class = OpTransposeConv2d
188
204
else :
189
- op_class = OpConv2d
205
+ op_class = OpConv2d if is_conv2d else OpConv3d
190
206
191
207
conv_op = PyQnnWrapper .PyQnnOpWrapper (
192
208
node .name ,
0 commit comments