@@ -626,6 +626,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
626626 [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
627627}
628628
629+ /* *
630+ * Execution plan for repeated reduce_over_dim_list with the same
631+ * function, input tensor, and dim_list but varying out_ix.
632+ */
633+ class ReduceOverDimListPlan {
634+ public:
635+ ReduceOverDimListPlan (
636+ const executorch::aten::Tensor& in,
637+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
638+ dim_list)
639+ : plan_(in, dim_list) {}
640+
641+ template <typename CTYPE, typename ReduceOp>
642+ CTYPE execute (const ReduceOp& reduce_fun, const size_t out_ix) {
643+ return plan_.execute <CTYPE, CTYPE>(
644+ [](CTYPE v) { return v; }, reduce_fun, out_ix);
645+ }
646+
647+ private:
648+ MapReduceOverDimListPlan plan_;
649+ };
650+
629651/* *
630652 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
631653 * for the output element at index `out_ix` using the reduce function
@@ -652,8 +674,8 @@ CTYPE reduce_over_dim_list(
652674 const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
653675 dim_list,
654676 const size_t out_ix) {
655- return map_reduce_over_dim_list<CTYPE, CTYPE>(
656- [](CTYPE v) { return v; }, reduce_fun, in, dim_list , out_ix);
677+ ReduceOverDimListPlan plan (in, dim_list);
678+ return plan. execute <CTYPE>( reduce_fun, out_ix);
657679}
658680
659681//
0 commit comments