@@ -45,7 +45,7 @@ template <typename Fn>
 void apply_on_flat_ix_with_dim_mask_and_base(
     const Fn& fn,
     const Tensor& in,
-    bool* dim_mask,
+    const bool* dim_mask,
     const size_t base,
     const size_t start,
     const size_t end) {
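The const-correctness fix above is a prerequisite for the plan class added below: its execute() method is const, so calling is_in_dim_list_.data() inside it yields a const bool*, which the old bool* parameter would reject. A minimal standalone illustration (not part of the commit):

#include <array>

struct Holder {
  std::array<bool, 4> mask_{};
  const bool* mask_data() const {
    // Inside a const member function the array member is const,
    // so data() resolves to the overload returning const bool*.
    return mask_.data();
  }
};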
@@ -295,6 +295,92 @@ void apply_over_dim(
   }
 }
 
+/**
+ * Execution plan for repeated apply_over_dim_list with the same
+ * function, input tensor, dim list, start, and end but varying
+ * out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
+ */
+class ApplyOverDimListPlan {
+ public:
+  ApplyOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      // If set, lifetime must last until execute() returns.
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list,
+      const int64_t start = 0,
+      const int64_t end = -1)
+      : in_(in) {
+    ET_CHECK(check_dim_list_is_valid(in, dim_list));
+    out_numel_ = get_out_numel(in_, dim_list);
+    if (in.numel() == 0) {
+      mode_ = ExecutionMode::NothingToDo;
+      return;
+    }
+    const size_t iter_length = get_reduced_dim_product(in, dim_list);
+    const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
+    const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
+    ustart_ = std::max(normalized_start, size_t(0));
+    uend_ = std::min(normalized_end, iter_length - 1);
+    if (!dim_list.has_value() || dim_list.value().size() == 0 ||
+        in.dim() == 0) {
+      mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
+      return;
+    }
+    dim_list_ = dim_list.value();
+    is_in_dim_list_.fill(0);
+    for (const auto& d : dim_list.value()) {
+      const size_t non_neg_d = d < 0 ? d + in.dim() : d;
+      is_in_dim_list_[non_neg_d] = true;
+    }
+
+    mode_ = ExecutionMode::NormalDimMask;
+  }
+
+  template <typename Fn>
+  void execute(const Fn& fn, const size_t out_ix) const {
+    ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
+
+    switch (mode_) {
+      case ExecutionMode::NothingToDo:
+        return;
+      case ExecutionMode::NoDimMaskOrZeroDimension:
+        apply_on_flat_ix_with_stride_and_base(
+            fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
+        return;
+      case ExecutionMode::NormalDimMask:
+        apply_on_flat_ix_with_dim_mask_and_base(
+            fn,
+            in_,
+            is_in_dim_list_.data(),
+            get_init_index(in_, dim_list_, out_ix),
+            ustart_,
+            uend_);
+        return;
+    }
+  }
+
+ private:
+  // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t ustart_;
+  // End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t uend_;
+  enum class ExecutionMode {
+    // Empty input, no work to do.
+    NothingToDo,
+    // Iterate over the entire tensor with
+    // apply_on_flat_ix_with_stride_and_base.
+    NoDimMaskOrZeroDimension,
+    // General mode, iterate with
+    // apply_on_flat_ix_with_dim_mask_and_base.
+    NormalDimMask
+  };
+  ExecutionMode mode_;
+  size_t out_numel_;
+  executorch::aten::ArrayRef<int64_t> dim_list_;
+  std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
+  const executorch::aten::Tensor& in_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
@@ -311,42 +397,8 @@ void apply_over_dim_list(
     const size_t out_ix,
     const int64_t start = 0,
     const int64_t end = -1) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  if (in.numel() == 0) {
-    return;
-  }
-
-  const size_t iter_length = get_reduced_dim_product(in, dim_list);
-  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
-  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
-  const size_t ustart = std::max(normalized_start, size_t(0));
-  const size_t uend = std::min(normalized_end, iter_length - 1);
-
-  // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
-  if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
-    apply_on_flat_ix_with_stride_and_base(
-        fn, /*stride=*/1, /*base=*/0, ustart, uend);
-    return;
-  }
-
-  // Create is_in_dims to check whether each dimension is in the dim list
-  bool is_in_dim_list[kTensorDimensionLimit];
-  memset(is_in_dim_list, false, sizeof(is_in_dim_list));
-  for (const auto& d : dim_list.value()) {
-    const size_t non_neg_d = d < 0 ? d + in.dim() : d;
-    is_in_dim_list[non_neg_d] = true;
-  }
-
-  // Compute the starting base index
-  const size_t base = get_init_index(in, dim_list, out_ix);
-
-  apply_on_flat_ix_with_dim_mask_and_base(
-      fn, in, is_in_dim_list, base, ustart, uend);
+  ApplyOverDimListPlan plan(in, dim_list, start, end);
+  plan.execute(fn, out_ix);
 }
 
 //
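For context, a sketch of the call pattern this plan enables: a reduction kernel can construct the plan once, paying for dim-list validation, index normalization, and dim-mask setup a single time, and then call execute() per output element. The kernel name, float dtype, and the const_data_ptr/mutable_data_ptr accessors below are assumptions for illustration, not part of this commit. The defaults follow the code above: ET_NORMALIZE_IX maps end = -1 to iter_length - 1, i.e. the full reduced range.

// Hypothetical sum-reduction kernel (sketch only; no error handling).
void sum_over_dim_list(
    const executorch::aten::Tensor& in,
    const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
        dim_list,
    executorch::aten::Tensor& out) {
  // Validation, normalization, and mask construction happen once here,
  // instead of once per output element as in the old apply_over_dim_list.
  ApplyOverDimListPlan plan(in, dim_list);
  const float* in_data = in.const_data_ptr<float>();
  float* out_data = out.mutable_data_ptr<float>();
  for (size_t out_ix = 0; out_ix < static_cast<size_t>(out.numel());
       ++out_ix) {
    float acc = 0.0f;
    // fn is invoked with the flat input indices that reduce into out_ix.
    plan.execute([&](const size_t in_ix) { acc += in_data[in_ix]; }, out_ix);
    out_data[out_ix] = acc;
  }
}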