From a6435aed3cbf11e227ca8a1dfa6a19dd709c7946 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 8 Oct 2024 14:55:31 -0700 Subject: [PATCH] [ET][Portable][Build Size] Reduce build size of op_cumsum 98 K -> 9 K Differential Revision: [D63997235](https://our.internmc.facebook.com/intern/diff/D63997235/) [ghstack-poisoned] --- kernels/portable/cpu/op_cumsum.cpp | 34 ++++++++++++------- .../kernels/portable/op_registration_util.bzl | 1 + 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/kernels/portable/cpu/op_cumsum.cpp b/kernels/portable/cpu/op_cumsum.cpp index 7b3eae5fe35..a061b57bda2 100644 --- a/kernels/portable/cpu/op_cumsum.cpp +++ b/kernels/portable/cpu/op_cumsum.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -34,17 +35,22 @@ namespace { * the memory level, thereby increasing the speed of memory IO as * well as reducing the number of cache misses. */ -template -void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) { +template +void cumsum_tensors( + const Tensor& self, + LoadFn load_self, + int64_t dim, + Tensor& out) { if (self.numel() == 0) { return; } - const CTYPE_IN* input_data_base = self.const_data_ptr(); + const char* const input_data_base = + reinterpret_cast(self.const_data_ptr()); CTYPE_OUT* output_data_base = out.mutable_data_ptr(); if (self.dim() == 0) { - output_data_base[0] = input_data_base[0]; + output_data_base[0] = load_self(&input_data_base[0]); return; } @@ -57,7 +63,7 @@ void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) { for (size_t idx = 0; idx < trailing_dims; idx++) { output_data_base[start_loc + idx] = - static_cast(input_data_base[start_loc + idx]); + load_self(&input_data_base[(start_loc + idx) * self.element_size()]); } for (size_t j = 1; j < dim_size; j++) { @@ -65,7 +71,8 @@ void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) { size_t prev_round_base = start_loc + (j - 1) * trailing_dims; for (size_t idx = 0; idx < trailing_dims; idx++) { output_data_base[cur_round_base + idx] = - static_cast(input_data_base[cur_round_base + idx]) + + load_self(&input_data_base + [(cur_round_base + idx) * self.element_size()]) + output_data_base[prev_round_base + idx]; } } @@ -101,13 +108,14 @@ Tensor& cumsum_out( dim = (self.dim() == 0) ? 0 : dim < 0 ? dim + self.dim() : dim; - ET_SWITCH_REAL_TYPES_AND( - Bool, self.scalar_type(), ctx, "cumsum", CTYPE_SELF, [&] { - ET_SWITCH_REAL_TYPES_AND( - Bool, out.scalar_type(), ctx, "cumsum", CTYPE_OUT, [&] { - cumsum_tensors(self, dim, out); - }); - }); + static constexpr const char op_name[] = "cumsum.out"; + + ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] { + const auto load_self = + utils::internal::get_load_to_common_fn( + self, utils::SupportedTensorDtypes::REALHBBF16); + cumsum_tensors(self, load_self, dim, out); + }); return out; } diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 538e20bd126..32b9cec7e24 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -447,6 +447,7 @@ ATEN_OPS = ( op_target( name = "op_cumsum", deps = [ + "//executorch/kernels/portable/cpu/util:dtype_util", "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", "//executorch/kernels/portable/cpu/util:kernel_ops_util",