1818 OpDepthWiseConv2d ,
1919 OpExpandDims ,
2020 OpReshape ,
21+ OpTransposeConv2d ,
2122 QNN_OP_PACKAGE_NAME_QTI_AISW ,
2223)
2324from .utils import get_parameter
@@ -42,6 +43,9 @@ def _add_conv_op_parameter(
4243 padding_shape ,
4344 dilation ,
4445 dilation_shape ,
46+ output_padding = None ,
47+ output_padding_shape = None ,
48+ transpose_conv = False ,
4549 groups = None ,
4650 ) -> PyQnnWrapper .PyQnnOpWrapper :
4751 """
@@ -68,14 +72,26 @@ def _add_conv_op_parameter(
6872 ),
6973 True ,
7074 )
71- conv_op .AddTensorParam (
72- OP .param_dilation ,
73- PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
74- len (dilation_shape ),
75- dilation_shape ,
76- np .array (dilation , dtype = np .uint32 ),
77- True ,
78- )
75+
76+ if transpose_conv :
77+ conv_op .AddTensorParam (
78+ OP .param_output_padding ,
79+ PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
80+ len (output_padding_shape ),
81+ output_padding_shape ,
82+ np .array (output_padding , dtype = np .uint32 ),
83+ True ,
84+ )
85+ else :
86+ conv_op .AddTensorParam (
87+ OP .param_dilation ,
88+ PyQnnWrapper .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
89+ len (dilation_shape ),
90+ dilation_shape ,
91+ np .array (dilation , dtype = np .uint32 ),
92+ True ,
93+ )
94+
7995 if groups is not None :
8096 conv_op .AddScalarParam (
8197 OP .param_group ,
@@ -94,6 +110,11 @@ def _define_conv1d(
94110 Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore,
95111 we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output.
96112 """
113+ transpose_conv = cast (bool , node .args [6 ])
114+ if transpose_conv :
115+ print ("ConvTranspose1d is not yet supported" )
116+ return
117+
97118 op_wrapper_list = [] # op_wrapper to return
98119 unsqueeze_input_node = node .args [0 ]
99120 input_quant_encoding , input_quant_configs = self .get_quant_encoding_conf (
@@ -239,9 +260,9 @@ def define_node(
239260 node : torch .fx .Node ,
240261 nodes_to_wrappers : Dict [str , PyQnnWrapper .TensorWrapper ],
241262 ) -> PyQnnWrapper .PyQnnOpWrapper :
242-
243263 if get_parameter (node .args [1 ], self .edge_program ).dim () == 3 :
244264 return self ._define_conv1d (node , nodes_to_wrappers )
265+
245266 input_node = node .args [0 ]
246267 input_tensor = self .get_tensor (input_node , node )
247268 input_tensor_wrapper = self .define_tensor (
@@ -254,8 +275,9 @@ def define_node(
254275
255276 filter_node = node .args [1 ]
256277 filter_tensor = get_parameter (filter_node , self .edge_program )
257- # weight of pytorch OIHW, yet QNN is HWIO
258- filter_axis_order = (2 , 3 , 1 , 0 )
278+ # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
279+ is_transpose_conv = cast (bool , node .args [6 ])
280+ filter_axis_order = (2 , 3 , 0 , 1 ) if is_transpose_conv else (2 , 3 , 1 , 0 )
259281 filter_tensor = filter_tensor .permute (dims = filter_axis_order ).contiguous ()
260282 filter_tensor_wrapper = self .define_tensor (
261283 filter_node ,
@@ -291,6 +313,7 @@ def define_node(
291313 stride = cast (List [int ], node .args [3 ])
292314 padding = cast (List [int ], node .args [4 ])
293315 dilation = cast (List [int ], node .args [5 ])
316+ output_padding = cast (List [int ], node .args [7 ])
294317
295318 groups = cast (int , node .args [8 ])
296319 # Qnn filter tensor is (H, W, Cin, Cout)
@@ -308,57 +331,38 @@ def define_node(
308331 if len (padding ) == 1 :
309332 padding = padding + padding
310333
311- # args[6] = transposed
312- if cast (bool , node .args [6 ]):
313- print ("Currently, No support for transposed convolution" )
314- return
315-
316- # args[7] = output padding
317- if not all (out_pad == 0 for out_pad in cast (List [int ], node .args [7 ])):
318- print ("QNN does not support output padding" )
319- return
320-
321334 stride_shape = [len (stride )]
322335 padding_shape = [2 , 2 ]
323336 dilation_shape = [len (dilation )]
337+ output_padding_shape = [len (output_padding )]
324338
325339 if is_depthwise_conv :
326- conv_op = PyQnnWrapper .PyQnnOpWrapper (
327- node .name ,
328- QNN_OP_PACKAGE_NAME_QTI_AISW ,
329- OpDepthWiseConv2d .op_name ,
330- )
331- conv_op = self ._add_conv_op_parameter (
332- OpDepthWiseConv2d ,
333- conv_op ,
334- conv_input_tensors ,
335- conv_output_tensors ,
336- stride ,
337- stride_shape ,
338- padding ,
339- padding_shape ,
340- dilation ,
341- dilation_shape ,
342- )
343-
340+ op_class = OpDepthWiseConv2d
341+ elif is_transpose_conv :
342+ op_class = OpTransposeConv2d
344343 else :
345- conv_op = PyQnnWrapper .PyQnnOpWrapper (
346- node .name ,
347- QNN_OP_PACKAGE_NAME_QTI_AISW ,
348- OpConv2d .op_name ,
349- )
350- conv_op = self ._add_conv_op_parameter (
351- OpConv2d ,
352- conv_op ,
353- conv_input_tensors ,
354- conv_output_tensors ,
355- stride ,
356- stride_shape ,
357- padding ,
358- padding_shape ,
359- dilation ,
360- dilation_shape ,
361- groups ,
362- )
344+ op_class = OpConv2d
345+
346+ conv_op = PyQnnWrapper .PyQnnOpWrapper (
347+ node .name ,
348+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
349+ op_class .op_name ,
350+ )
351+ conv_op = self ._add_conv_op_parameter (
352+ op_class ,
353+ conv_op ,
354+ conv_input_tensors ,
355+ conv_output_tensors ,
356+ stride ,
357+ stride_shape ,
358+ padding ,
359+ padding_shape ,
360+ dilation ,
361+ dilation_shape ,
362+ output_padding ,
363+ output_padding_shape ,
364+ is_transpose_conv ,
365+ None if is_depthwise_conv else groups ,
366+ )
363367
364368 return conv_op
0 commit comments