Skip to content

Commit 1d43a90

Browse files
committed
Update
[ghstack-poisoned]
1 parent 60e6ce3 commit 1d43a90

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
606606
[](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
607607
}
608608

609+
/**
610+
* Execution plan for repeated reduce_over_dim_list with the same
611+
* function, input tensor, and dim_list but varying out_ix.
612+
*/
613+
class ReduceOverDimListPlan {
614+
public:
615+
ReduceOverDimListPlan(
616+
const executorch::aten::Tensor& in,
617+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
618+
dim_list)
619+
: plan_(in, dim_list) {}
620+
621+
template <typename CTYPE, typename ReduceOp>
622+
CTYPE execute(const ReduceOp& reduce_fun, const size_t out_ix) {
623+
return plan_.execute<CTYPE, CTYPE>(
624+
[](CTYPE v) { return v; }, reduce_fun, out_ix);
625+
}
626+
627+
private:
628+
MapReduceOverDimListPlan plan_;
629+
};
630+
609631
/**
610632
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
611633
* for the output element at index `out_ix` using the reduce function
@@ -632,8 +654,8 @@ CTYPE reduce_over_dim_list(
632654
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
633655
dim_list,
634656
const size_t out_ix) {
635-
return map_reduce_over_dim_list<CTYPE, CTYPE>(
636-
[](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
657+
ReduceOverDimListPlan plan(in, dim_list);
658+
return plan.execute<CTYPE>(reduce_fun, out_ix);
637659
}
638660

639661
//

0 commit comments

Comments
 (0)