@@ -89,67 +89,101 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(
89
89
90
90
WORD32 scratch_size = 0 ;
91
91
92
- if (groups == 1 ) {
93
- WORD32 out_data_format = 1 ;
94
-
95
- WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory (
96
- ctx,
97
- ((batches * input_channels * input_height * input_width) + 8 ) *
98
- sizeof (WORD8));
99
-
100
- WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory (
101
- ctx,
102
- ((out_channels * kernel_channels * kernel_height * kernel_width) + 8 ) *
103
- sizeof (WORD8));
104
-
105
- WORD8* pin = (WORD8*)ALIGN_PTR (ptr1, 8 );
106
- WORD8* pkernel = (WORD8*)ALIGN_PTR (ptr2, 8 );
107
-
108
- WORD32 p_inp_shape[kNnlibMaxDim ];
109
- p_inp_shape[0 ] = input.size (0 );
110
- p_inp_shape[1 ] = input_channels;
111
- p_inp_shape[2 ] = input_height;
112
- p_inp_shape[3 ] = input_width;
113
-
114
- WORD32 p_out_shape[kNnlibMaxDim ];
115
- p_out_shape[0 ] = input.size (0 );
116
- p_out_shape[1 ] = input_height;
117
- p_out_shape[2 ] = input_width;
118
- p_out_shape[3 ] = input_channels;
119
-
120
- WORD32 p_permute_vec[kNnlibMaxDim ] = {0 , 2 , 3 , 1 };
121
-
122
- xa_nn_transpose_8_8 (
123
- pin,
124
- p_out_shape,
125
- p_inp,
126
- p_inp_shape,
127
- p_permute_vec,
128
- kNnlibMaxDim ,
129
- kNnlibMaxDim );
130
-
131
- WORD32 p_inp_shape1[kNnlibMaxDim ];
132
- p_inp_shape1[0 ] = out_channels;
133
- p_inp_shape1[1 ] = kernel_channels;
134
- p_inp_shape1[2 ] = kernel_height;
135
- p_inp_shape1[3 ] = kernel_width;
136
-
137
- WORD32 p_out_shape1[kNnlibMaxDim ];
138
- p_out_shape1[0 ] = out_channels;
139
- p_out_shape1[1 ] = kernel_height;
140
- p_out_shape1[2 ] = kernel_width;
141
- p_out_shape1[3 ] = kernel_channels;
142
-
143
- xa_nn_transpose_8_8 (
92
+ ET_CHECK_MSG (groups == 1 , " Only groups=1 supported for regular convolution" );
93
+ WORD32 out_data_format = 1 ;
94
+
95
+ WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory (
96
+ ctx,
97
+ ((batches * input_channels * input_height * input_width) + 8 ) *
98
+ sizeof (WORD8));
99
+
100
+ WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory (
101
+ ctx,
102
+ ((out_channels * kernel_channels * kernel_height * kernel_width) + 8 ) *
103
+ sizeof (WORD8));
104
+
105
+ WORD8* pin = (WORD8*)ALIGN_PTR (ptr1, 8 );
106
+ WORD8* pkernel = (WORD8*)ALIGN_PTR (ptr2, 8 );
107
+
108
+ WORD32 p_inp_shape[kNnlibMaxDim ];
109
+ p_inp_shape[0 ] = input.size (0 );
110
+ p_inp_shape[1 ] = input_channels;
111
+ p_inp_shape[2 ] = input_height;
112
+ p_inp_shape[3 ] = input_width;
113
+
114
+ WORD32 p_out_shape[kNnlibMaxDim ];
115
+ p_out_shape[0 ] = input.size (0 );
116
+ p_out_shape[1 ] = input_height;
117
+ p_out_shape[2 ] = input_width;
118
+ p_out_shape[3 ] = input_channels;
119
+
120
+ WORD32 p_permute_vec[kNnlibMaxDim ] = {0 , 2 , 3 , 1 };
121
+
122
+ xa_nn_transpose_8_8 (
123
+ pin,
124
+ p_out_shape,
125
+ p_inp,
126
+ p_inp_shape,
127
+ p_permute_vec,
128
+ kNnlibMaxDim ,
129
+ kNnlibMaxDim );
130
+
131
+ WORD32 p_inp_shape1[kNnlibMaxDim ];
132
+ p_inp_shape1[0 ] = out_channels;
133
+ p_inp_shape1[1 ] = kernel_channels;
134
+ p_inp_shape1[2 ] = kernel_height;
135
+ p_inp_shape1[3 ] = kernel_width;
136
+
137
+ WORD32 p_out_shape1[kNnlibMaxDim ];
138
+ p_out_shape1[0 ] = out_channels;
139
+ p_out_shape1[1 ] = kernel_height;
140
+ p_out_shape1[2 ] = kernel_width;
141
+ p_out_shape1[3 ] = kernel_channels;
142
+
143
+ xa_nn_transpose_8_8 (
144
+ pkernel,
145
+ p_out_shape1,
146
+ p_kernel,
147
+ p_inp_shape1,
148
+ p_permute_vec,
149
+ kNnlibMaxDim ,
150
+ kNnlibMaxDim );
151
+
152
+ scratch_size = xa_nn_conv2d_getsize (
153
+ input_height,
154
+ input_width,
155
+ input_channels,
156
+ kernel_height,
157
+ kernel_width,
158
+ kernel_channels,
159
+ dilation_height,
160
+ dilation_width,
161
+ y_stride,
162
+ y_padding,
163
+ x_stride,
164
+ x_padding,
165
+ out_height,
166
+ out_width,
167
+ out_channels,
168
+ inp_precision,
169
+ kernel_precision,
170
+ out_data_format);
171
+
172
+ scratch_size = scratch_size < 0 ? 0 : scratch_size;
173
+
174
+ ptr_scratch = (WORD32*)kernels::allocate_temp_memory (ctx, scratch_size);
175
+
176
+ p_scratch = (pVOID)ALIGN_PTR (ptr_scratch, 8 );
177
+
178
+ for (int _n = 0 ; _n < batches; _n++) {
179
+ WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
180
+ WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
181
+
182
+ xa_nn_conv2d_per_chan_sym8sxasym8s (
183
+ out_batch,
184
+ in_batch,
144
185
pkernel,
145
- p_out_shape1,
146
- p_kernel,
147
- p_inp_shape1,
148
- p_permute_vec,
149
- kNnlibMaxDim ,
150
- kNnlibMaxDim );
151
-
152
- scratch_size = xa_nn_conv2d_getsize (
186
+ p_bias,
153
187
input_height,
154
188
input_width,
155
189
input_channels,
@@ -158,59 +192,20 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s(
158
192
kernel_channels,
159
193
dilation_height,
160
194
dilation_width,
161
- y_stride,
162
- y_padding,
195
+ out_channels,
163
196
x_stride,
197
+ y_stride,
164
198
x_padding,
199
+ y_padding,
165
200
out_height,
166
201
out_width,
167
- out_channels,
168
- inp_precision,
169
- kernel_precision,
170
- out_data_format);
171
-
172
- scratch_size = scratch_size < 0 ? 0 : scratch_size;
173
-
174
- ptr_scratch = (WORD32*)kernels::allocate_temp_memory (ctx, scratch_size);
175
-
176
- p_scratch = (pVOID)ALIGN_PTR (ptr_scratch, 8 );
177
-
178
- for (int _n = 0 ; _n < batches; _n++) {
179
- WORD8* in_batch = pin + _n * input_channels * input_height * input_width;
180
- WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
181
-
182
- xa_nn_conv2d_per_chan_sym8sxasym8s (
183
- out_batch,
184
- in_batch,
185
- pkernel,
186
- p_bias,
187
- input_height,
188
- input_width,
189
- input_channels,
190
- kernel_height,
191
- kernel_width,
192
- kernel_channels,
193
- dilation_height,
194
- dilation_width,
195
- out_channels,
196
- x_stride,
197
- y_stride,
198
- x_padding,
199
- y_padding,
200
- out_height,
201
- out_width,
202
- input_zero_bias,
203
- out_multiplier32,
204
- out_shift32,
205
- out_zero_bias,
206
- out_data_format,
207
- p_scratch);
208
- }
209
- return ;
202
+ input_zero_bias,
203
+ out_multiplier32,
204
+ out_shift32,
205
+ out_zero_bias,
206
+ out_data_format,
207
+ p_scratch);
210
208
}
211
-
212
- // Depthwise convolutions are now handled by specialized operators
213
- ET_CHECK_MSG (groups == 1 , " Only groups=1 supported for regular convolution" );
214
209
}
215
210
216
211
void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out (
0 commit comments