From 7df1701ae04eb495012cc70877778c036727bdde Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 8 Oct 2024 14:55:21 -0700
Subject: [PATCH] [ET][Portable][Build Size] Reduce build size of
 op_convolution 500 K -> 52 K

Differential Revision: [D63994876](https://our.internmc.facebook.com/intern/diff/D63994876/)

[ghstack-poisoned]
---
 kernels/portable/cpu/op_convolution.cpp       | 55 ++++++++++++-------
 .../kernels/portable/op_registration_util.bzl |  1 +
 2 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/kernels/portable/cpu/op_convolution.cpp b/kernels/portable/cpu/op_convolution.cpp
index af6b164f301..d57fa4e3b82 100644
--- a/kernels/portable/cpu/op_convolution.cpp
+++ b/kernels/portable/cpu/op_convolution.cpp
@@ -8,6 +8,7 @@
 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -32,7 +33,7 @@ namespace {
  * in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
  * out_W.
  */
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn>
 void conv2d_impl(
     const CTYPE* const in_ptr,
     SizesArrayRef in_sizes,
@@ -40,7 +41,9 @@ void conv2d_impl(
     const CTYPE* const w_ptr,
     SizesArrayRef w_sizes,
     StridesArrayRef w_strides,
-    const CTYPE_BIAS* const bias_ptr,
+    const exec_aten::optional<Tensor>& bias,
+    const char* const bias_ptr,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -128,7 +131,7 @@ void conv2d_impl(
         }
 
         if (bias_ptr != nullptr) {
-          accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+          accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
         }
         size_t out_idx =
             calculate_linear_index(out_coord, out_strides.data(), 4);
@@ -185,11 +188,12 @@ void conv2d_impl(
   }
 }
 
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn>
 void convolution_wrapper(
     const Tensor& in,
     const Tensor& weight,
     const exec_aten::optional<Tensor>& bias,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -280,8 +284,9 @@ void convolution_wrapper(
   CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
   const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
   const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
-  const CTYPE_BIAS* const bias_ptr =
-      bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
+  const char* const bias_ptr = bias.has_value()
+      ? reinterpret_cast<const char*>(bias.value().const_data_ptr())
+      : nullptr;
 
   size_t out_N = out.size(0);
   size_t out_C = out.size(1);
@@ -296,8 +301,9 @@ void convolution_wrapper(
     } else {
       // If bias is present, we initialize the output to the bias value
      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-        out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
-            bias_ptr[(out_ix / out_strides[1]) % out_C]);
+        out_ptr[out_ix] = load_bias(&bias_ptr
+            [((out_ix / out_strides[1]) % out_C) *
+             bias.value().element_size()]);
      }
    }
  }
@@ -316,7 +322,9 @@ void convolution_wrapper(
             w_ptr,
             weight_sizes,
             {weight_strides, 4},
+            bias,
             bias_ptr,
+            load_bias,
             stride_,
             padding_,
             dilation_,
@@ -398,19 +406,24 @@ Tensor& convolution_out(
     return out;
   }
 
-  ScalarType in_type = in.scalar_type();
-  ScalarType bias_type = in_type;
-  if (bias.has_value()) {
-    bias_type = bias.value().scalar_type();
-  }
-
-  constexpr auto name = "convolution.out";
-
-  ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
-    ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
-      convolution_wrapper<CTYPE, CTYPE_BIAS>(
-          in, weight, bias, stride, padding, dilation, transposed, groups, out);
-    });
+  static constexpr const char name[] = "convolution.out";
+
+  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const auto load_bias = bias.has_value()
+        ? utils::internal::get_load_to_common_fn<CTYPE, name>(
+              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
+        : nullptr;
+    convolution_wrapper<CTYPE>(
+        in,
+        weight,
+        bias,
+        load_bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        groups,
+        out);
   });
 
   return out;
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 92b940e6305..8e9069f8134 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -408,6 +408,7 @@ ATEN_OPS = (
     op_target(
         name = "op_convolution",
         deps = [
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             ":vec_ops",
         ],
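
Note on the technique (an illustrative sketch, not part of the patch): the old code nested ET_SWITCH_REALHB_TYPES over the bias dtype inside ET_SWITCH_REALH_TYPES over the input dtype, so convolution_wrapper was instantiated for every (input dtype, bias dtype) pair. The patch keeps a single switch on the input dtype and type-erases the bias behind a runtime element loader obtained from utils::internal::get_load_to_common_fn, so only a tiny per-dtype loader varies while the heavy kernel is instantiated once per input dtype. Below is a minimal, self-contained C++ sketch of that pattern; all names in it (Dtype, load_as, get_loader, add_bias) are hypothetical stand-ins, not the ExecuTorch API.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class Dtype { Float, Half, Int8 };  // stand-in for ScalarType

// A loader reads one element of an arbitrary source dtype and converts it
// to the kernel's compute dtype CTYPE.
template <typename CTYPE>
using LoadFn = CTYPE (*)(const void*);

template <typename CTYPE, typename SRC>
CTYPE load_as(const void* p) {
  return static_cast<CTYPE>(*reinterpret_cast<const SRC*>(p));
}

// Runtime switch over the bias dtype: only these tiny loaders are
// instantiated per (compute dtype, bias dtype) pair, not the whole kernel.
template <typename CTYPE>
LoadFn<CTYPE> get_loader(Dtype bias_dtype) {
  switch (bias_dtype) {
    case Dtype::Float:
      return load_as<CTYPE, float>;
    case Dtype::Int8:
      return load_as<CTYPE, int8_t>;
    default:
      return nullptr;  // unsupported in this sketch
  }
}

// "Heavy" kernel: templated only on the compute dtype. The bias is passed
// as raw bytes plus its element size, like bias_ptr in the patch.
template <typename CTYPE>
void add_bias(
    CTYPE* out,
    size_t n,
    const char* bias_bytes,
    size_t bias_elem_size,
    LoadFn<CTYPE> load_bias) {
  for (size_t i = 0; i < n; ++i) {
    out[i] += load_bias(&bias_bytes[i * bias_elem_size]);
  }
}

int main() {
  float out[3] = {1.f, 2.f, 3.f};
  int8_t bias[3] = {10, 20, 30};
  auto load_bias = get_loader<float>(Dtype::Int8);
  add_bias<float>(
      out, 3, reinterpret_cast<const char*>(bias), sizeof(int8_t), load_bias);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints 11 22 33
  return 0;
}
```

In the sketch, add_bias<float> is the only kernel instantiation regardless of how many bias dtypes get_loader supports, which is the same effect the patch relies on to shrink op_convolution from roughly 500 K to 52 K.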