 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -32,15 +33,17 @@ namespace {
  * in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
  * out_W.
  */
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
 void conv2d_impl(
     const CTYPE* const in_ptr,
     SizesArrayRef in_sizes,
     StridesArrayRef in_strides,
     const CTYPE* const w_ptr,
     SizesArrayRef w_sizes,
     StridesArrayRef w_strides,
-    const CTYPE_BIAS* const bias_ptr,
+    const exec_aten::optional<Tensor>& bias,
+    const char* const bias_ptr,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -128,7 +131,7 @@ void conv2d_impl(
       }
 
       if (bias_ptr != nullptr) {
-        accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+        accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
       }
       size_t out_idx =
           calculate_linear_index(out_coord, out_strides.data(), 4);
@@ -185,11 +188,12 @@ void conv2d_impl(
   }
 }
 
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
 void convolution_wrapper(
     const Tensor& in,
     const Tensor& weight,
     const exec_aten::optional<Tensor>& bias,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -280,8 +284,9 @@ void convolution_wrapper(
   CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
   const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
   const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
-  const CTYPE_BIAS* const bias_ptr =
-      bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
+  const char* const bias_ptr = bias.has_value()
+      ? reinterpret_cast<const char*>(bias.value().const_data_ptr())
+      : nullptr;
 
   size_t out_N = out.size(0);
   size_t out_C = out.size(1);
@@ -296,8 +301,9 @@ void convolution_wrapper(
     } else {
       // If bias is present, we initialize the output to the bias value
       for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-        out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
-            bias_ptr[(out_ix / out_strides[1]) % out_C]);
+        out_ptr[out_ix] = load_bias(&bias_ptr
+            [((out_ix / out_strides[1]) % out_C) *
+             bias.value().element_size()]);
       }
     }
   }
@@ -316,7 +322,9 @@ void convolution_wrapper(
               w_ptr,
               weight_sizes,
               {weight_strides, 4},
+              bias,
               bias_ptr,
+              load_bias,
               stride_,
               padding_,
               dilation_,
@@ -398,19 +406,25 @@ Tensor& convolution_out(
     return out;
   }
 
-  ScalarType in_type = in.scalar_type();
-  ScalarType bias_type = in_type;
-  if (bias.has_value()) {
-    bias_type = bias.value().scalar_type();
-  }
-
-  constexpr auto name = "convolution.out";
-
-  ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
-    ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
-      convolution_wrapper<CTYPE, CTYPE_BIAS>(
-          in, weight, bias, stride, padding, dilation, transposed, groups, out);
-    });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char name[] = "convolution.out";
+
+  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const auto load_bias = bias.has_value()
+        ? utils::internal::get_load_to_common_fn<CTYPE, name>(
+              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
+        : nullptr;
+    convolution_wrapper<CTYPE>(
+        in,
+        weight,
+        bias,
+        load_bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        groups,
+        out);
   });
 
   return out;
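
Taken together, this diff erases the bias dtype from the kernel's template signature: rather than instantiating `conv2d_impl` for every `(CTYPE, CTYPE_BIAS)` pair via nested `ET_SWITCH` macros, the bias buffer is addressed as raw bytes (`const char*`, stepped by `element_size()`) and each element is converted through a single `load_bias` function pointer obtained from `dtype_util.h`. Below is a minimal, self-contained sketch of that type-erased load pattern; it is not ExecuTorch code, and the `load_int16_to_float` helper plus the hard-coded `float`/`int16_t` dtypes are assumptions chosen only for illustration.

```cpp
#include <cstddef>
#include <cstdint>

// Stand-in for the kernel's compute type (the real code selects this
// via ET_SWITCH_REALH_TYPES).
using CTYPE = float;

// Type-erased loader: reads one element of some source dtype (here
// int16_t, a hypothetical bias dtype) and converts it to CTYPE. The
// real kernel gets an equivalent pointer from
// utils::internal::get_load_to_common_fn.
CTYPE load_int16_to_float(const void* p) {
  return static_cast<CTYPE>(*static_cast<const int16_t*>(p));
}

int main() {
  const int16_t bias[] = {1, 2, 3};
  // Address the buffer as raw bytes, as the patched kernel does.
  const char* const bias_ptr = reinterpret_cast<const char*>(bias);
  const size_t element_size = sizeof(int16_t);  // tensor.element_size()

  // Mirrors the patched indexing: byte offset = index * element_size.
  CTYPE (*load_bias)(const void*) = &load_int16_to_float;
  CTYPE accum = 0;
  for (size_t out_c = 0; out_c < 3; ++out_c) {
    accum += load_bias(&bias_ptr[out_c * element_size]);
  }
  return accum == 6.0f ? 0 : 1;  // exit 0 on the expected sum
}
```

The payoff of this shape is code size: only one `conv2d_impl`/`convolution_wrapper` instantiation per compute type is emitted, and the per-dtype conversion cost collapses into one indirect call per bias element instead of a combinatorial set of template instantiations.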