@@ -273,6 +273,134 @@ __attribute__((noinline)) void conv2d_nhwc_core_generic(
   }
 }
 
+void convolution_nchw(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int16_t groups,
+    Tensor& output) {
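+  // A 1-D input ([n, c, w]) is handled as the degenerate 2-D case
+  // [n, c, 1, w]: the height extent and every height-direction conv
+  // parameter (stride, padding, dilation) collapse to the identity.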
+  bool conv1d = input.dim() == 3;
+  // input = [n, c, h, w]
+  const int n = input.size(0);
+  const int c = input.size(1);
+  const int h = conv1d ? 1 : input.size(2);
+  const int w = conv1d ? input.size(2) : input.size(3);
+  // weight = [oc, wc, wh, ww]
+  const int oc = weight.size(0);
+  const int wc = weight.size(1);
+  const int wh = conv1d ? 1 : weight.size(2);
+  const int ww = conv1d ? weight.size(2) : weight.size(3);
+  // output = [n, oc, oh, ow]
+  const int oh = conv1d ? 1 : output.size(2);
+  const int ow = conv1d ? output.size(2) : output.size(3);
+
+  float* __restrict__ p_out = output.mutable_data_ptr<float>();
+  const float* __restrict__ p_in = input.const_data_ptr<float>();
+  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
+  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
+
+  conv2d_nchw_core_generic<>(
+      p_in,
+      p_weight,
+      p_bias,
+      p_out,
+      n,
+      c,
+      h,
+      w,
+      oc,
+      wc,
+      wh,
+      ww,
+      oh,
+      ow,
+      conv1d ? 1 : stride[0],
+      conv1d ? stride[0] : stride[1],
+      conv1d ? 0 : padding[0],
+      conv1d ? padding[0] : padding[1],
+      conv1d ? 1 : dilation[0],
+      conv1d ? dilation[0] : dilation[1],
+      groups);
+}
+
+void convolution_nhwc(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int16_t groups,
+    Tensor& output) {
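+  // Channels-last mirror of convolution_nchw: a 1-D input ([n, w, c]) is
+  // handled as the degenerate 2-D case [n, 1, w, c].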
+  bool conv1d = input.dim() == 3;
+  // input = [n, h, w, c]
+  const int n = input.size(0);
+  const int h = conv1d ? 1 : input.size(1);
+  const int w = conv1d ? input.size(1) : input.size(2);
+  const int c = conv1d ? input.size(2) : input.size(3);
+
+  // weight = [oc, wh, ww, wc]
+  const int oc = weight.size(0);
+  const int wh = conv1d ? 1 : weight.size(1);
+  const int ww = conv1d ? weight.size(1) : weight.size(2);
+  const int wc = conv1d ? weight.size(2) : weight.size(3);
+
+  // output = [n, oh, ow, oc]
+  const int oh = conv1d ? 1 : output.size(1);
+  const int ow = conv1d ? output.size(1) : output.size(2);
+
+  float* __restrict__ p_out = output.mutable_data_ptr<float>();
+  const float* __restrict__ p_in = input.const_data_ptr<float>();
+  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
+  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
+
+  conv2d_nhwc_core_generic<>(
+      p_in,
+      p_weight,
+      p_bias,
+      p_out,
+      n,
+      h,
+      w,
+      c,
+      oc,
+      wh,
+      ww,
+      wc,
+      oh,
+      ow,
+      conv1d ? 1 : stride[0],
+      conv1d ? stride[0] : stride[1],
+      conv1d ? 0 : padding[0],
+      conv1d ? padding[0] : padding[1],
+      conv1d ? 1 : dilation[0],
+      conv1d ? dilation[0] : dilation[1],
+      groups);
+}
+
+void convolution_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int64_t groups,
+    bool channel_last,
+    Tensor& output) {
+  if (channel_last) {
+    convolution_nhwc(
+        input, weight, bias, stride, padding, dilation, groups, output);
+  } else {
+    convolution_nchw(
+        input, weight, bias, stride, padding, dilation, groups, output);
+  }
+}
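+
+// Usage sketch (hypothetical shapes, not part of this change): convolving a
+// [1, 3, 8, 8] NCHW input with a [4, 3, 3, 3] weight at unit stride, zero
+// padding, unit dilation and groups = 1 writes a [1, 4, 6, 6] output, since
+// each spatial extent follows
+// oh = (h + 2 * padding - dilation * (wh - 1) - 1) / stride + 1 = 6:
+//
+//   convolution_out(
+//       ctx, input, weight, bias, stride, padding, dilation,
+//       /*groups=*/1, /*channel_last=*/false, output);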
+
 // The quantized convolution kernel. in_scale and weight_scale are implicit in
 // bias_scale, since it is a product of the two. The kernel will branch to
 // quantized::conv1d or quantized::conv2d based on the dimensionality of