 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_DATA
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
+    OpConv3d,
     OpDepthWiseConv2d,
     OpTransposeConv2d,
+    OpTransposeConv3d,
     QNN_OP_PACKAGE_NAME_QTI_AISW,
 )
 from .utils import get_parameter
@@ -66,7 +67,7 @@ def _add_conv_op_parameter(
             len(padding_shape),
             padding_shape,
             np.array(
-                [[padding[0], padding[0]], [padding[1], padding[1]]],
+                padding,
                 dtype=np.uint32,
             ),
             True,
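After this change the helper no longer expands the pad amounts itself; it expects `padding` to arrive already as one `[before, after]` pair per spatial dimension (built later in `define_node` via `[[x, x] for x in padding]`). A minimal sketch with made-up values, not part of the diff:

```python
import numpy as np

# One [before, after] pair per spatial dimension, so the same code path
# produces a valid QNN pad_amount array for both conv2d and conv3d.
padding_2d = [[1, 1], [2, 2]]          # (H, W)
padding_3d = [[0, 0], [1, 1], [2, 2]]  # (D, H, W)

print(np.array(padding_2d, dtype=np.uint32).shape)  # (2, 2)
print(np.array(padding_3d, dtype=np.uint32).shape)  # (3, 2)
```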
@@ -108,8 +109,14 @@ def define_node(
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         assert (
-            input_tensor.dim() == 4
+            input_tensor.dim() != 3
         ), "All Conv1D should be converted to Conv2D in CanonicalizeConv,"
+        assert input_tensor.dim() in {
+            4,
+            5,
+        }, "Only Conv2d and Conv3d is supported in conv builder,"
+
+        is_conv2d = input_tensor.dim() == 4
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -120,9 +127,15 @@ def define_node(

         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
-        # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
+        # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
+        # yet QNN is HWIO or DHWIO
         is_transpose_conv = cast(bool, node.args[6])
-        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        if is_conv2d:
+            filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        else:
+            filter_axis_order = (
+                (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0)
+            )
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
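For reference, a standalone sketch of the new 3D axis order (example shapes are illustrative, not taken from the diff):

```python
import torch

# PyTorch conv3d weight layout is OIDHW: (out_channels, in_channels/groups, D, H, W).
weight_oidhw = torch.randn(8, 4, 3, 3, 3)

# QNN expects DHWIO, hence the (2, 3, 4, 1, 0) order used above for a regular conv3d.
weight_dhwio = weight_oidhw.permute(2, 3, 4, 1, 0).contiguous()
print(weight_dhwio.shape)  # torch.Size([3, 3, 3, 4, 8])

# conv_transpose3d weights are IODHW in PyTorch, so the last two axes swap: (2, 3, 4, 0, 1).
```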
@@ -132,7 +145,6 @@ def define_node(
             nodes_to_wrappers,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
             bias_tensor = get_parameter(bias_node, self.edge_program)
@@ -159,11 +171,10 @@ def define_node(
         padding = cast(List[int], node.args[4])
         dilation = cast(List[int], node.args[5])
         output_padding = cast(List[int], node.args[7])
-
         groups = cast(int, node.args[8])
-        # Qnn filter tensor is (H, W, Cin, Cout)
-        group_input_channels = filter_tensor.shape[2]
-        group_output_channels = int(filter_tensor.shape[3] / groups)
+        # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
+        group_input_channels = filter_tensor.shape[-2]
+        group_output_channels = int(filter_tensor.shape[-1] / groups)
         # 1) groups = input_channels (i.e. group_input_channels = 1)
         # 2) output_channels is a positive integer multiple of input channels
         # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
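With the filter already permuted into QNN layout, the channel dimensions are always the last two, which is why negative indexing works for both ranks. A quick sketch with made-up shapes:

```python
# Hypothetical grouped convolutions: HWIO for conv2d, DHWIO for conv3d.
groups = 2
filter_shape_2d = (3, 3, 4, 8)
filter_shape_3d = (3, 3, 3, 4, 8)

for shape in (filter_shape_2d, filter_shape_3d):
    group_input_channels = shape[-2]                  # in_channels per group
    group_output_channels = int(shape[-1] / groups)   # out_channels per group
    print(group_input_channels, group_output_channels)  # 4 4
```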
@@ -175,18 +186,23 @@ def define_node(
         )
         if len(padding) == 1:
             padding = padding + padding
+        padding = [[x, x] for x in padding]

         stride_shape = [len(stride)]
-        padding_shape = [2, 2]
+        padding_shape = [len(padding), len(padding[0])]
         dilation_shape = [len(dilation)]
         output_padding_shape = [len(output_padding)]

-        if is_depthwise_conv:
+        if is_transpose_conv:
+            assert all(
+                val == 1 for val in dilation
+            ), "CanonicalizeConv pass should perform dilate for transpose_conv."
+            op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
+        elif is_depthwise_conv:
+            assert is_conv2d, "DepthWise only supports Conv2d"
             op_class = OpDepthWiseConv2d
-        elif is_transpose_conv:
-            op_class = OpTransposeConv2d
         else:
-            op_class = OpConv2d
+            op_class = OpConv2d if is_conv2d else OpConv3d

         conv_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,