@@ -309,7 +309,7 @@ class ApplyOverDimListPlan {
309309 dim_list,
310310 const int64_t start = 0 ,
311311 const int64_t end = -1 )
312- : in_(in) {
312+ : dim_list_(dim_list), in_(in) {
313313 ET_CHECK (check_dim_list_is_valid (in, dim_list));
314314 out_numel_ = get_out_numel (in_, dim_list);
315315 if (in.numel () == 0 ) {
@@ -352,13 +352,22 @@ class ApplyOverDimListPlan {
352352 fn,
353353 in_,
354354 is_in_dim_list_.data (),
355- get_init_index (in_, dim_list_, out_ix),
355+ get_init_index (in_, dim_list_. value () , out_ix),
356356 ustart_,
357357 uend_);
358358 return ;
359359 }
360360 }
361361
362+ const executorch::aten::Tensor& get_input_tensor () const {
363+ return in_;
364+ }
365+
366+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
367+ get_dim_list () const {
368+ return dim_list_;
369+ }
370+
362371 private:
363372 // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
364373 size_t ustart_;
@@ -376,7 +385,7 @@ class ApplyOverDimListPlan {
376385 };
377386 ExecutionMode mode_;
378387 size_t out_numel_;
379- executorch::aten::ArrayRef<int64_t > dim_list_;
388+ executorch::aten::optional<executorch::aten:: ArrayRef<int64_t > > dim_list_;
380389 std::array<bool , kTensorDimensionLimit > is_in_dim_list_;
381390 const executorch::aten::Tensor& in_;
382391};
@@ -482,6 +491,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
482491 return std::tuple<CTYPE_OUT, long >{acc_val, acc_ix};
483492}
484493
494+ /* *
495+ * Execution plan for repeated map_reduce_over_dim_list with the same
496+ * function, input tensor, and dim_list but varying out_ix.
497+ */
498+ class MapReduceOverDimListPlan {
499+ public:
500+ MapReduceOverDimListPlan (
501+ const executorch::aten::Tensor& in,
502+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
503+ dim_list)
504+ : plan_(in, dim_list, 1 , -1 ) {
505+ ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
506+ }
507+
508+ template <
509+ typename CTYPE_IN,
510+ typename CTYPE_OUT,
511+ typename MapOp,
512+ typename ReduceOp>
513+ CTYPE_OUT execute (
514+ const MapOp& map_fun,
515+ const ReduceOp& reduce_fun,
516+ const size_t out_ix) const {
517+ const size_t init_index =
518+ get_init_index (plan_.get_input_tensor (), plan_.get_dim_list (), out_ix);
519+
520+ const CTYPE_IN* const in_data =
521+ plan_.get_input_tensor ().const_data_ptr <CTYPE_IN>();
522+ CTYPE_OUT acc_val = map_fun (in_data[init_index]);
523+
524+ if (plan_.get_input_tensor ().numel () == 1 ) {
525+ return acc_val;
526+ }
527+
528+ plan_.execute (
529+ [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
530+ acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
531+ },
532+ out_ix);
533+ return acc_val;
534+ }
535+
536+ private:
537+ ApplyOverDimListPlan plan_;
538+ };
539+
485540/* *
486541 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
487542 * for the output element at index `out_ix`, first applying the map `map_fun`
@@ -517,35 +572,8 @@ CTYPE_OUT map_reduce_over_dim_list(
517572 const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
518573 dim_list,
519574 const size_t out_ix) {
520- ET_CHECK (check_dim_list_is_valid (in, dim_list));
521-
522- ET_CHECK_MSG (
523- out_ix < get_out_numel (in, dim_list),
524- " Out index %zd is out of bounds" ,
525- out_ix);
526-
527- ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
528-
529- const size_t init_index = get_init_index (in, dim_list, out_ix);
530-
531- const CTYPE_IN* const in_data = in.const_data_ptr <CTYPE_IN>();
532- CTYPE_OUT acc_val = map_fun (in_data[init_index]);
533-
534- if (in.numel () == 1 ) {
535- return acc_val;
536- }
537-
538- apply_over_dim_list (
539- [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
540- acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
541- },
542- in,
543- dim_list,
544- out_ix,
545- 1 ,
546- -1 );
547-
548- return acc_val;
575+ MapReduceOverDimListPlan plan (in, dim_list);
576+ return plan.execute <CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
549577}
550578
551579/* *
0 commit comments