Skip to content

Commit 60e6ce3

Browse files
committed
Update
[ghstack-poisoned]
1 parent f278976 commit 60e6ce3

File tree

1 file changed

+60
-32
lines changed

1 file changed

+60
-32
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ class ApplyOverDimListPlan {
309309
dim_list,
310310
const int64_t start = 0,
311311
const int64_t end = -1)
312-
: in_(in) {
312+
: dim_list_(dim_list), in_(in) {
313313
ET_CHECK(check_dim_list_is_valid(in, dim_list));
314314
out_numel_ = get_out_numel(in_, dim_list);
315315
if (in.numel() == 0) {
@@ -352,13 +352,22 @@ class ApplyOverDimListPlan {
352352
fn,
353353
in_,
354354
is_in_dim_list_.data(),
355-
get_init_index(in_, dim_list_, out_ix),
355+
get_init_index(in_, dim_list_.value(), out_ix),
356356
ustart_,
357357
uend_);
358358
return;
359359
}
360360
}
361361

362+
const executorch::aten::Tensor& get_input_tensor() const {
363+
return in_;
364+
}
365+
366+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
367+
get_dim_list() const {
368+
return dim_list_;
369+
}
370+
362371
private:
363372
// Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
364373
size_t ustart_;
@@ -376,7 +385,7 @@ class ApplyOverDimListPlan {
376385
};
377386
ExecutionMode mode_;
378387
size_t out_numel_;
379-
executorch::aten::ArrayRef<int64_t> dim_list_;
388+
executorch::aten::optional<executorch::aten::ArrayRef<int64_t>> dim_list_;
380389
std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
381390
const executorch::aten::Tensor& in_;
382391
};
@@ -482,6 +491,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
482491
return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
483492
}
484493

494+
/**
495+
* Execution plan for repeated map_reduce_over_dim_list with the same
496+
* function, input tensor, and dim_list but varying out_ix.
497+
*/
498+
class MapReduceOverDimListPlan {
499+
public:
500+
MapReduceOverDimListPlan(
501+
const executorch::aten::Tensor& in,
502+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
503+
dim_list)
504+
: plan_(in, dim_list, 1, -1) {
505+
ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
506+
}
507+
508+
template <
509+
typename CTYPE_IN,
510+
typename CTYPE_OUT,
511+
typename MapOp,
512+
typename ReduceOp>
513+
CTYPE_OUT execute(
514+
const MapOp& map_fun,
515+
const ReduceOp& reduce_fun,
516+
const size_t out_ix) const {
517+
const size_t init_index =
518+
get_init_index(plan_.get_input_tensor(), plan_.get_dim_list(), out_ix);
519+
520+
const CTYPE_IN* const in_data =
521+
plan_.get_input_tensor().const_data_ptr<CTYPE_IN>();
522+
CTYPE_OUT acc_val = map_fun(in_data[init_index]);
523+
524+
if (plan_.get_input_tensor().numel() == 1) {
525+
return acc_val;
526+
}
527+
528+
plan_.execute(
529+
[&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
530+
acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
531+
},
532+
out_ix);
533+
return acc_val;
534+
}
535+
536+
private:
537+
ApplyOverDimListPlan plan_;
538+
};
539+
485540
/**
486541
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
487542
* for the output element at index `out_ix`, first applying the map `map_fun`
@@ -517,35 +572,8 @@ CTYPE_OUT map_reduce_over_dim_list(
517572
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
518573
dim_list,
519574
const size_t out_ix) {
520-
ET_CHECK(check_dim_list_is_valid(in, dim_list));
521-
522-
ET_CHECK_MSG(
523-
out_ix < get_out_numel(in, dim_list),
524-
"Out index %zd is out of bounds",
525-
out_ix);
526-
527-
ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
528-
529-
const size_t init_index = get_init_index(in, dim_list, out_ix);
530-
531-
const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
532-
CTYPE_OUT acc_val = map_fun(in_data[init_index]);
533-
534-
if (in.numel() == 1) {
535-
return acc_val;
536-
}
537-
538-
apply_over_dim_list(
539-
[&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
540-
acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
541-
},
542-
in,
543-
dim_list,
544-
out_ix,
545-
1,
546-
-1);
547-
548-
return acc_val;
575+
MapReduceOverDimListPlan plan(in, dim_list);
576+
return plan.execute<CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
549577
}
550578

551579
/**

0 commit comments

Comments
 (0)