@@ -89,67 +89,101 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(
 
   WORD32 scratch_size = 0;
 
-  if (groups == 1) {
-    WORD32 out_data_format = 1;
-
-    WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory(
-        ctx,
-        ((batches * input_channels * input_height * input_width) + 8) *
-            sizeof(WORD8));
-
-    WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory(
-        ctx,
-        ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) *
-            sizeof(WORD8));
-
-    WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8);
-    WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);
-
-    WORD32 p_inp_shape[kNnlibMaxDim];
-    p_inp_shape[0] = input.size(0);
-    p_inp_shape[1] = input_channels;
-    p_inp_shape[2] = input_height;
-    p_inp_shape[3] = input_width;
-
-    WORD32 p_out_shape[kNnlibMaxDim];
-    p_out_shape[0] = input.size(0);
-    p_out_shape[1] = input_height;
-    p_out_shape[2] = input_width;
-    p_out_shape[3] = input_channels;
-
-    WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};
-
-    xa_nn_transpose_8_8(
-        pin,
-        p_out_shape,
-        p_inp,
-        p_inp_shape,
-        p_permute_vec,
-        kNnlibMaxDim,
-        kNnlibMaxDim);
-
-    WORD32 p_inp_shape1[kNnlibMaxDim];
-    p_inp_shape1[0] = out_channels;
-    p_inp_shape1[1] = kernel_channels;
-    p_inp_shape1[2] = kernel_height;
-    p_inp_shape1[3] = kernel_width;
-
-    WORD32 p_out_shape1[kNnlibMaxDim];
-    p_out_shape1[0] = out_channels;
-    p_out_shape1[1] = kernel_height;
-    p_out_shape1[2] = kernel_width;
-    p_out_shape1[3] = kernel_channels;
-
-    xa_nn_transpose_8_8(
+  ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution");
+  WORD32 out_data_format = 1;
+
+  WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory(
+      ctx,
+      ((batches * input_channels * input_height * input_width) + 8) *
+          sizeof(WORD8));
+
+  WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory(
+      ctx,
+      ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) *
+          sizeof(WORD8));
+
+  WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8);
+  WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);
+
+  WORD32 p_inp_shape[kNnlibMaxDim];
+  p_inp_shape[0] = input.size(0);
+  p_inp_shape[1] = input_channels;
+  p_inp_shape[2] = input_height;
+  p_inp_shape[3] = input_width;
+
+  WORD32 p_out_shape[kNnlibMaxDim];
+  p_out_shape[0] = input.size(0);
+  p_out_shape[1] = input_height;
+  p_out_shape[2] = input_width;
+  p_out_shape[3] = input_channels;
+
+  WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};
+
+  xa_nn_transpose_8_8(
+      pin,
+      p_out_shape,
+      p_inp,
+      p_inp_shape,
+      p_permute_vec,
+      kNnlibMaxDim,
+      kNnlibMaxDim);
+
+  WORD32 p_inp_shape1[kNnlibMaxDim];
+  p_inp_shape1[0] = out_channels;
+  p_inp_shape1[1] = kernel_channels;
+  p_inp_shape1[2] = kernel_height;
+  p_inp_shape1[3] = kernel_width;
+
+  WORD32 p_out_shape1[kNnlibMaxDim];
+  p_out_shape1[0] = out_channels;
+  p_out_shape1[1] = kernel_height;
+  p_out_shape1[2] = kernel_width;
+  p_out_shape1[3] = kernel_channels;
+
+  xa_nn_transpose_8_8(
+      pkernel,
+      p_out_shape1,
+      p_kernel,
+      p_inp_shape1,
+      p_permute_vec,
+      kNnlibMaxDim,
+      kNnlibMaxDim);
+
+  scratch_size = xa_nn_conv2d_getsize(
+      input_height,
+      input_width,
+      input_channels,
+      kernel_height,
+      kernel_width,
+      kernel_channels,
+      dilation_height,
+      dilation_width,
+      y_stride,
+      y_padding,
+      x_stride,
+      x_padding,
+      out_height,
+      out_width,
+      out_channels,
+      inp_precision,
+      kernel_precision,
+      out_data_format);
+
+  scratch_size = scratch_size < 0 ? 0 : scratch_size;
+
+  ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
+
+  p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
+
+  for (int _n = 0; _n < batches; _n++) {
+    WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
+    WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
+
+    xa_nn_conv2d_per_chan_sym8sxasym8s(
+        out_batch,
+        in_batch,
         pkernel,
-        p_out_shape1,
-        p_kernel,
-        p_inp_shape1,
-        p_permute_vec,
-        kNnlibMaxDim,
-        kNnlibMaxDim);
-
-    scratch_size = xa_nn_conv2d_getsize(
+        p_bias,
         input_height,
         input_width,
         input_channels,
@@ -158,59 +192,20 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(
         kernel_channels,
         dilation_height,
         dilation_width,
-        y_stride,
-        y_padding,
+        out_channels,
         x_stride,
+        y_stride,
         x_padding,
+        y_padding,
         out_height,
         out_width,
-        out_channels,
-        inp_precision,
-        kernel_precision,
-        out_data_format);
-
-    scratch_size = scratch_size < 0 ? 0 : scratch_size;
-
-    ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
-
-    p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
-
-    for (int _n = 0; _n < batches; _n++) {
-      WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
-      WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
-
-      xa_nn_conv2d_per_chan_sym8sxasym8s(
-          out_batch,
-          in_batch,
-          pkernel,
-          p_bias,
-          input_height,
-          input_width,
-          input_channels,
-          kernel_height,
-          kernel_width,
-          kernel_channels,
-          dilation_height,
-          dilation_width,
-          out_channels,
-          x_stride,
-          y_stride,
-          x_padding,
-          y_padding,
-          out_height,
-          out_width,
-          input_zero_bias,
-          out_multiplier32,
-          out_shift32,
-          out_zero_bias,
-          out_data_format,
-          p_scratch);
-    }
-    return;
+        input_zero_bias,
+        out_multiplier32,
+        out_shift32,
+        out_zero_bias,
+        out_data_format,
+        p_scratch);
   }
-
-  // Depthwise convolutions are now handled by specialized operators
-  ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution");
 }
 
 void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out(
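
For reference, the `{0, 2, 3, 1}` permute vector passed to `xa_nn_transpose_8_8` in this diff reorders the NCHW activations (and, applied to the weights, the OIHW kernels) into a channels-last layout before the per-batch convolution call. Below is a minimal sketch of the equivalent index mapping written as a plain loop; it assumes contiguous `int8_t` buffers, and `nchw_to_nhwc_ref` is an illustrative name, not part of NNLib or this kernel.

```c
#include <stdint.h>

// Reference-only sketch of the {0, 2, 3, 1} (NCHW -> NHWC) permute; the
// kernel above uses the optimized xa_nn_transpose_8_8 routine instead.
static void nchw_to_nhwc_ref(
    int8_t* dst,
    const int8_t* src,
    int batches,
    int channels,
    int height,
    int width) {
  for (int n = 0; n < batches; n++) {
    for (int c = 0; c < channels; c++) {
      for (int h = 0; h < height; h++) {
        for (int w = 0; w < width; w++) {
          // Read in NCHW order, write in NHWC order.
          dst[((n * height + h) * width + w) * channels + c] =
              src[((n * channels + c) * height + h) * width + w];
        }
      }
    }
  }
}
```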