  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -34,17 +35,22 @@ namespace {
  * the memory level, thereby increasing the speed of memory IO as
  * well as reducing the number of cache misses.
  */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) {
+template <typename CTYPE_OUT, typename LoadFn = CTYPE_OUT (*)(const void*)>
+void cumsum_tensors(
+    const Tensor& self,
+    LoadFn load_self,
+    int64_t dim,
+    Tensor& out) {
   if (self.numel() == 0) {
     return;
   }
 
-  const CTYPE_IN* input_data_base = self.const_data_ptr<CTYPE_IN>();
+  const char* const input_data_base =
+      reinterpret_cast<const char*>(self.const_data_ptr());
   CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
 
   if (self.dim() == 0) {
-    output_data_base[0] = input_data_base[0];
+    output_data_base[0] = load_self(&input_data_base[0]);
     return;
   }
 
@@ -57,15 +63,16 @@ void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) {
 
     for (size_t idx = 0; idx < trailing_dims; idx++) {
       output_data_base[start_loc + idx] =
-          static_cast<CTYPE_OUT>(input_data_base[start_loc + idx]);
+          load_self(&input_data_base[(start_loc + idx) * self.element_size()]);
     }
 
     for (size_t j = 1; j < dim_size; j++) {
       size_t cur_round_base = start_loc + j * trailing_dims;
       size_t prev_round_base = start_loc + (j - 1) * trailing_dims;
       for (size_t idx = 0; idx < trailing_dims; idx++) {
         output_data_base[cur_round_base + idx] =
-            static_cast<CTYPE_OUT>(input_data_base[cur_round_base + idx]) +
+            load_self(&input_data_base
+                          [(cur_round_base + idx) * self.element_size()]) +
             output_data_base[prev_round_base + idx];
       }
     }
@@ -101,13 +108,15 @@ Tensor& cumsum_out(
 
   dim = (self.dim() == 0) ? 0 : dim < 0 ? dim + self.dim() : dim;
 
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, self.scalar_type(), ctx, "cumsum", CTYPE_SELF, [&] {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, out.scalar_type(), ctx, "cumsum", CTYPE_OUT, [&] {
-              cumsum_tensors<CTYPE_SELF, CTYPE_OUT>(self, dim, out);
-            });
-      });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "cumsum.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+    const auto load_self =
+        utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
+            self, utils::SupportedTensorDtypes::REALHBBF16);
+    cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
+  });
 
   return out;
 }
0 commit comments