@@ -329,7 +329,7 @@ class ApplyOverDimListPlan {
329329 dim_list,
330330 const int64_t start = 0 ,
331331 const int64_t end = -1 )
332- : in_(in) {
332+ : dim_list_(dim_list), in_(in) {
333333 ET_CHECK (check_dim_list_is_valid (in, dim_list));
334334 out_numel_ = get_out_numel (in_, dim_list);
335335 if (in.numel () == 0 ) {
@@ -372,13 +372,22 @@ class ApplyOverDimListPlan {
372372 fn,
373373 in_,
374374 is_in_dim_list_.data (),
375- get_init_index (in_, dim_list_, out_ix),
375+ get_init_index (in_, dim_list_. value () , out_ix),
376376 ustart_,
377377 uend_);
378378 return ;
379379 }
380380 }
381381
382+ const executorch::aten::Tensor& get_input_tensor () const {
383+ return in_;
384+ }
385+
386+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
387+ get_dim_list () const {
388+ return dim_list_;
389+ }
390+
382391 private:
383392 // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
384393 size_t ustart_;
@@ -396,7 +405,7 @@ class ApplyOverDimListPlan {
396405 };
397406 ExecutionMode mode_;
398407 size_t out_numel_;
399- executorch::aten::ArrayRef<int64_t > dim_list_;
408+ executorch::aten::optional<executorch::aten:: ArrayRef<int64_t > > dim_list_;
400409 std::array<bool , kTensorDimensionLimit > is_in_dim_list_;
401410 const executorch::aten::Tensor& in_;
402411};
@@ -502,6 +511,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
502511 return std::tuple<CTYPE_OUT, long >{acc_val, acc_ix};
503512}
504513
514+ /* *
515+ * Execution plan for repeated map_reduce_over_dim_list with the same
516+ * function, input tensor, and dim_list but varying out_ix.
517+ */
518+ class MapReduceOverDimListPlan {
519+ public:
520+ MapReduceOverDimListPlan (
521+ const executorch::aten::Tensor& in,
522+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
523+ dim_list)
524+ : plan_(in, dim_list, 1 , -1 ) {
525+ ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
526+ }
527+
528+ template <
529+ typename CTYPE_IN,
530+ typename CTYPE_OUT,
531+ typename MapOp,
532+ typename ReduceOp>
533+ CTYPE_OUT execute (
534+ const MapOp& map_fun,
535+ const ReduceOp& reduce_fun,
536+ const size_t out_ix) const {
537+ const size_t init_index =
538+ get_init_index (plan_.get_input_tensor (), plan_.get_dim_list (), out_ix);
539+
540+ const CTYPE_IN* const in_data =
541+ plan_.get_input_tensor ().const_data_ptr <CTYPE_IN>();
542+ CTYPE_OUT acc_val = map_fun (in_data[init_index]);
543+
544+ if (plan_.get_input_tensor ().numel () == 1 ) {
545+ return acc_val;
546+ }
547+
548+ plan_.execute (
549+ [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
550+ acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
551+ },
552+ out_ix);
553+ return acc_val;
554+ }
555+
556+ private:
557+ ApplyOverDimListPlan plan_;
558+ };
559+
505560/* *
506561 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
507562 * for the output element at index `out_ix`, first applying the map `map_fun`
@@ -537,35 +592,8 @@ CTYPE_OUT map_reduce_over_dim_list(
537592 const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
538593 dim_list,
539594 const size_t out_ix) {
540- ET_CHECK (check_dim_list_is_valid (in, dim_list));
541-
542- ET_CHECK_MSG (
543- out_ix < get_out_numel (in, dim_list),
544- " Out index %zd is out of bounds" ,
545- out_ix);
546-
547- ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
548-
549- const size_t init_index = get_init_index (in, dim_list, out_ix);
550-
551- const CTYPE_IN* const in_data = in.const_data_ptr <CTYPE_IN>();
552- CTYPE_OUT acc_val = map_fun (in_data[init_index]);
553-
554- if (in.numel () == 1 ) {
555- return acc_val;
556- }
557-
558- apply_over_dim_list (
559- [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
560- acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
561- },
562- in,
563- dim_list,
564- out_ix,
565- 1 ,
566- -1 );
567-
568- return acc_val;
595+ MapReduceOverDimListPlan plan (in, dim_list);
596+ return plan.execute <CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
569597}
570598
571599/* *
0 commit comments