@@ -45,7 +45,7 @@ template <typename Fn>
 void apply_on_flat_ix_with_dim_mask_and_base(
     const Fn& fn,
     const Tensor& in,
-    bool* dim_mask,
+    const bool* dim_mask,
     const size_t base,
     const size_t start,
     const size_t end) {
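The change above makes `dim_mask` a read-only pointer, which it already was in practice. It is also what allows the new `ApplyOverDimListPlan` class below to call this helper from its `const` `execute()` method, where `is_in_dim_list_.data()` yields a `const bool*`.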
@@ -315,6 +315,92 @@ void apply_over_dim(
   }
 }
 
+/**
+ * Execution plan for repeated apply_over_dim_list with the same
+ * function, input tensor, dim list, start, and end but varying
+ * out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
+ */
+class ApplyOverDimListPlan {
+ public:
+  ApplyOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      // If set, lifetime must last until execute() returns.
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list,
+      const int64_t start = 0,
+      const int64_t end = -1)
+      : in_(in) {
+    ET_CHECK(check_dim_list_is_valid(in, dim_list));
+    out_numel_ = get_out_numel(in_, dim_list);
+    if (in.numel() == 0) {
+      mode_ = ExecutionMode::NothingToDo;
+      return;
+    }
+    const size_t iter_length = get_reduced_dim_product(in, dim_list);
+    const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
+    const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
+    ustart_ = std::max(normalized_start, size_t(0));
+    uend_ = std::min(normalized_end, iter_length - 1);
+    if (!dim_list.has_value() || dim_list.value().size() == 0 ||
+        in.dim() == 0) {
+      mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
+      return;
+    }
+    dim_list_ = dim_list.value();
+    is_in_dim_list_.fill(0);
+    for (const auto& d : dim_list.value()) {
+      const size_t non_neg_d = d < 0 ? d + in.dim() : d;
+      is_in_dim_list_[non_neg_d] = true;
+    }
+
+    mode_ = ExecutionMode::NormalDimMask;
+  }
+
+  template <typename Fn>
+  void execute(const Fn& fn, const size_t out_ix) const {
+    ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
+
+    switch (mode_) {
+      case ExecutionMode::NothingToDo:
+        return;
+      case ExecutionMode::NoDimMaskOrZeroDimension:
+        apply_on_flat_ix_with_stride_and_base(
+            fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
+        return;
+      case ExecutionMode::NormalDimMask:
+        apply_on_flat_ix_with_dim_mask_and_base(
+            fn,
+            in_,
+            is_in_dim_list_.data(),
+            get_init_index(in_, dim_list_, out_ix),
+            ustart_,
+            uend_);
+        return;
+    }
+  }
+
+ private:
+  // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t ustart_;
+  // End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t uend_;
+  enum class ExecutionMode {
+    // Empty input, no work to do.
+    NothingToDo,
+    // Iterate over the entire tensor with
+    // apply_on_flat_ix_with_stride_and_base.
+    NoDimMaskOrZeroDimension,
+    // General mode, iterate with
+    // apply_on_flat_ix_with_dim_mask_and_base.
+    NormalDimMask
+  };
+  ExecutionMode mode_;
+  size_t out_numel_;
+  executorch::aten::ArrayRef<int64_t> dim_list_;
+  std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
+  const executorch::aten::Tensor& in_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
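To illustrate the usage pattern the class comment describes, here is a minimal sketch (not part of this commit) of how a reduction kernel might hoist the plan out of its per-output loop. `sum_over_dim_list` is a hypothetical caller; the sketch assumes a float input and the ExecuTorch `const_data_ptr`/`mutable_data_ptr` tensor accessors:

```cpp
// Sketch only, not from this commit: build one plan and reuse it for
// every output element. Assumes `in` holds float data and `out` was
// sized to get_out_numel(in, dim_list) elements.
void sum_over_dim_list(
    const executorch::aten::Tensor& in,
    const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
        dim_list,
    executorch::aten::Tensor& out) {
  // Validation, dim-mask setup, and start/end normalization happen
  // once, here, rather than once per output element.
  ApplyOverDimListPlan plan(in, dim_list);
  const float* in_data = in.const_data_ptr<float>();
  float* out_data = out.mutable_data_ptr<float>();
  for (size_t out_ix = 0; out_ix < static_cast<size_t>(out.numel());
       ++out_ix) {
    float sum = 0;
    // execute() invokes the lambda with the flat index of each input
    // element that reduces into out_data[out_ix].
    plan.execute([&](const size_t in_ix) { sum += in_data[in_ix]; }, out_ix);
    out_data[out_ix] = sum;
  }
}
```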
@@ -331,42 +417,8 @@ void apply_over_dim_list(
     const size_t out_ix,
     const int64_t start = 0,
     const int64_t end = -1) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  if (in.numel() == 0) {
-    return;
-  }
-
-  const size_t iter_length = get_reduced_dim_product(in, dim_list);
-  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
-  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
-  const size_t ustart = std::max(normalized_start, size_t(0));
-  const size_t uend = std::min(normalized_end, iter_length - 1);
-
-  // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
-  if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
-    apply_on_flat_ix_with_stride_and_base(
-        fn, /*stride=*/1, /*base=*/0, ustart, uend);
-    return;
-  }
-
-  // Create is_in_dims to check whether each dimension is in the dim list
-  bool is_in_dim_list[kTensorDimensionLimit];
-  memset(is_in_dim_list, false, sizeof(is_in_dim_list));
-  for (const auto& d : dim_list.value()) {
-    const size_t non_neg_d = d < 0 ? d + in.dim() : d;
-    is_in_dim_list[non_neg_d] = true;
-  }
-
-  // Compute the starting base index
-  const size_t base = get_init_index(in, dim_list, out_ix);
-
-  apply_on_flat_ix_with_dim_mask_and_base(
-      fn, in, is_in_dim_list, base, ustart, uend);
+  ApplyOverDimListPlan plan(in, dim_list, start, end);
+  plan.execute(fn, out_ix);
 }
 
 //
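Note that the plan constructor performs the same validation and mode selection that the old inline body of `apply_over_dim_list` did, so one-shot callers keep their existing behavior and signature; the payoff comes when callers such as `{map_,}reduce_over_dim_list` hoist the plan out of their per-`out_ix` loop, as the class comment describes.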