@@ -606,6 +606,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
606606 [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
607607}
608608
609+ /* *
610+ * Execution plan for repeated reduce_over_dim_list with the same
611+ * function, input tensor, and dim_list but varying out_ix.
612+ */
613+ class ReduceOverDimListPlan {
614+ public:
615+ ReduceOverDimListPlan (
616+ const executorch::aten::Tensor& in,
617+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
618+ dim_list)
619+ : plan_(in, dim_list) {}
620+
621+ template <typename CTYPE, typename ReduceOp>
622+ CTYPE execute (const ReduceOp& reduce_fun, const size_t out_ix) {
623+ return plan_.execute <CTYPE, CTYPE>(
624+ [](CTYPE v) { return v; }, reduce_fun, out_ix);
625+ }
626+
627+ private:
628+ MapReduceOverDimListPlan plan_;
629+ };
630+
609631/* *
610632 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
611633 * for the output element at index `out_ix` using the reduce function
@@ -632,8 +654,8 @@ CTYPE reduce_over_dim_list(
632654 const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
633655 dim_list,
634656 const size_t out_ix) {
635- return map_reduce_over_dim_list<CTYPE, CTYPE>(
636- [](CTYPE v) { return v; }, reduce_fun, in, dim_list , out_ix);
657+ ReduceOverDimListPlan plan (in, dim_list);
658+ return plan. execute <CTYPE>( reduce_fun, out_ix);
637659}
638660
639661//
0 commit comments