@@ -273,134 +273,6 @@ __attribute__((noinline)) void conv2d_nhwc_core_generic(
   }
 }
 
-void convolution_nchw(
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int16_t groups,
-    Tensor& output) {
-  bool conv1d = input.dim() == 3;
-  // input = [n, c, h, w]
-  const int n = input.size(0);
-  const int c = input.size(1);
-  const int h = conv1d ? 1 : input.size(2);
-  const int w = conv1d ? input.size(2) : input.size(3);
-  // weight = [oc, wc, wh, ww]
-  const int oc = weight.size(0);
-  const int wc = weight.size(1);
-  const int wh = conv1d ? 1 : weight.size(2);
-  const int ww = conv1d ? weight.size(2) : weight.size(3);
-  // output = [n, oc, oh, ow]
-  const int oh = conv1d ? 1 : output.size(2);
-  const int ow = conv1d ? output.size(2) : output.size(3);
-
-  float* __restrict__ p_out = output.mutable_data_ptr<float>();
-  const float* __restrict__ p_in = input.const_data_ptr<float>();
-  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
-  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
-
-  conv2d_nchw_core_generic<>(
-      p_in,
-      p_weight,
-      p_bias,
-      p_out,
-      n,
-      c,
-      h,
-      w,
-      oc,
-      wc,
-      wh,
-      ww,
-      oh,
-      ow,
-      conv1d ? 1 : stride[0],
-      conv1d ? stride[0] : stride[1],
-      conv1d ? 0 : padding[0],
-      conv1d ? padding[0] : padding[1],
-      conv1d ? 1 : dilation[0],
-      conv1d ? dilation[0] : dilation[1],
-      groups);
-}
-
-void convolution_nhwc(
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int16_t groups,
-    Tensor& output) {
-  bool conv1d = input.dim() == 3;
-  // input = [n, h, w, c]
-  const int n = input.size(0);
-  const int h = conv1d ? 1 : input.size(1);
-  const int w = conv1d ? input.size(1) : input.size(2);
-  const int c = conv1d ? input.size(2) : input.size(3);
-
-  // weight = [oc, wh, ww, wc]
-  const int oc = weight.size(0);
-  const int wh = conv1d ? 1 : weight.size(1);
-  const int ww = conv1d ? weight.size(1) : weight.size(2);
-  const int wc = conv1d ? weight.size(2) : weight.size(3);
-
-  // output = [n, oh, ow, oc]
-  const int oh = conv1d ? 1 : output.size(1);
-  const int ow = conv1d ? output.size(1) : output.size(2);
-
-  float* __restrict__ p_out = output.mutable_data_ptr<float>();
-  const float* __restrict__ p_in = input.const_data_ptr<float>();
-  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
-  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
-
-  conv2d_nhwc_core_generic<>(
-      p_in,
-      p_weight,
-      p_bias,
-      p_out,
-      n,
-      h,
-      w,
-      c,
-      oc,
-      wh,
-      ww,
-      wc,
-      oh,
-      ow,
-      conv1d ? 1 : stride[0],
-      conv1d ? stride[0] : stride[1],
-      conv1d ? 0 : padding[0],
-      conv1d ? padding[0] : padding[1],
-      conv1d ? 1 : dilation[0],
-      conv1d ? dilation[0] : dilation[1],
-      groups);
-}
-
-void convolution_out(
-    __ET_UNUSED KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int64_t groups,
-    bool channel_last,
-    Tensor& output) {
-  if (channel_last) {
-    convolution_nhwc(
-        input, weight, bias, stride, padding, dilation, groups, output);
-  } else {
-    convolution_nchw(
-        input, weight, bias, stride, padding, dilation, groups, output);
-  }
-}
-
 // The quantized convolution kernel. in_scale and weight_scale are implicit in
 // bias_scale, since it is a product of the two. The kernel will branch to
 // quantized::conv1d or quantized::conv2d based on the dimensionality of
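The comment in the surviving context lines hinges on one identity: bias_scale is the product of in_scale and weight_scale, so the int32 accumulator of quantized input-weight products and the quantized bias share the same scale and can be summed directly before a single rescale to the output. A minimal sketch of that requantization step, assuming symmetric int8 weights (zero point 0); the helper name requantize_accumulator and its signature are illustrative, not the kernel's actual API:

#include <cmath>
#include <cstdint>

// Each product (in_q - in_zp) * w_q carries a scale of
// in_scale * weight_scale == bias_scale, so a bias quantized as
// round(bias_f / bias_scale) adds straight into the int32 accumulator.
int8_t requantize_accumulator(
    int32_t acc, // sum of (in_q - in_zp) * w_q terms, plus the quantized bias
    float bias_scale, // == in_scale * weight_scale
    float out_scale,
    int32_t out_zero_point) {
  // Map the accumulator back to the real-valued domain.
  float real = static_cast<float>(acc) * bias_scale;
  // Quantize to the output's scale and zero point.
  int32_t out_q =
      static_cast<int32_t>(std::round(real / out_scale)) + out_zero_point;
  // Clamp to the int8 range.
  out_q = out_q < -128 ? -128 : (out_q > 127 ? 127 : out_q);
  return static_cast<int8_t>(out_q);
}

Because the accumulator already carries bias_scale, the kernel never needs in_scale and weight_scale separately, which is what the comment above means by them being implicit in bias_scale.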