Skip to content

Commit f278976

Browse files
committed
Update
[ghstack-poisoned]
1 parent 46d0580 commit f278976

File tree

1 file changed

+89
-37
lines changed

1 file changed

+89
-37
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 89 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ template <typename Fn>
4545
void apply_on_flat_ix_with_dim_mask_and_base(
4646
const Fn& fn,
4747
const Tensor& in,
48-
bool* dim_mask,
48+
const bool* dim_mask,
4949
const size_t base,
5050
const size_t start,
5151
const size_t end) {
@@ -295,6 +295,92 @@ void apply_over_dim(
295295
}
296296
}
297297

298+
/**
299+
* Execution plan for repeated apply_over_dim_list with the same
300+
* function, input tensor, dim list, start, and end but varying
301+
* out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
302+
*/
303+
class ApplyOverDimListPlan {
304+
public:
305+
ApplyOverDimListPlan(
306+
const executorch::aten::Tensor& in,
307+
// If set, lifetime must last until execute() returns.
308+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
309+
dim_list,
310+
const int64_t start = 0,
311+
const int64_t end = -1)
312+
: in_(in) {
313+
ET_CHECK(check_dim_list_is_valid(in, dim_list));
314+
out_numel_ = get_out_numel(in_, dim_list);
315+
if (in.numel() == 0) {
316+
mode_ = ExecutionMode::NothingToDo;
317+
return;
318+
}
319+
const size_t iter_length = get_reduced_dim_product(in, dim_list);
320+
const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
321+
const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
322+
ustart_ = std::max(normalized_start, size_t(0));
323+
uend_ = std::min(normalized_end, iter_length - 1);
324+
if (!dim_list.has_value() || dim_list.value().size() == 0 ||
325+
in.dim() == 0) {
326+
mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
327+
return;
328+
}
329+
dim_list_ = dim_list.value();
330+
is_in_dim_list_.fill(0);
331+
for (const auto& d : dim_list.value()) {
332+
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
333+
is_in_dim_list_[non_neg_d] = true;
334+
}
335+
336+
mode_ = ExecutionMode::NormalDimMask;
337+
}
338+
339+
template <typename Fn>
340+
void execute(const Fn& fn, const size_t out_ix) const {
341+
ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
342+
343+
switch (mode_) {
344+
case ExecutionMode::NothingToDo:
345+
return;
346+
case ExecutionMode::NoDimMaskOrZeroDimension:
347+
apply_on_flat_ix_with_stride_and_base(
348+
fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
349+
return;
350+
case ExecutionMode::NormalDimMask:
351+
apply_on_flat_ix_with_dim_mask_and_base(
352+
fn,
353+
in_,
354+
is_in_dim_list_.data(),
355+
get_init_index(in_, dim_list_, out_ix),
356+
ustart_,
357+
uend_);
358+
return;
359+
}
360+
}
361+
362+
private:
363+
// Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
364+
size_t ustart_;
365+
// End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
366+
size_t uend_;
367+
enum class ExecutionMode {
368+
// Empty input, no work to do.
369+
NothingToDo,
370+
// Iterate over the entire tensor with
371+
// apply_on_flat_ix_with_stride_and_base.
372+
NoDimMaskOrZeroDimension,
373+
// General mode, iterate with
374+
// apply_on_flat_ix_with_dim_mask_and_base.
375+
NormalDimMask
376+
};
377+
ExecutionMode mode_;
378+
size_t out_numel_;
379+
executorch::aten::ArrayRef<int64_t> dim_list_;
380+
std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
381+
const executorch::aten::Tensor& in_;
382+
};
383+
298384
/**
299385
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
300386
* for the output element at index `out_ix` using the reduce function
@@ -311,42 +397,8 @@ void apply_over_dim_list(
311397
const size_t out_ix,
312398
const int64_t start = 0,
313399
const int64_t end = -1) {
314-
ET_CHECK(check_dim_list_is_valid(in, dim_list));
315-
ET_CHECK_MSG(
316-
out_ix < get_out_numel(in, dim_list),
317-
"Out index %zd is out of bounds",
318-
out_ix);
319-
320-
if (in.numel() == 0) {
321-
return;
322-
}
323-
324-
const size_t iter_length = get_reduced_dim_product(in, dim_list);
325-
const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
326-
const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
327-
const size_t ustart = std::max(normalized_start, size_t(0));
328-
const size_t uend = std::min(normalized_end, iter_length - 1);
329-
330-
// If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
331-
if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
332-
apply_on_flat_ix_with_stride_and_base(
333-
fn, /*stride=*/1, /*base=*/0, ustart, uend);
334-
return;
335-
}
336-
337-
// Create is_in_dims to check whether each dimension is in the dim list
338-
bool is_in_dim_list[kTensorDimensionLimit];
339-
memset(is_in_dim_list, false, sizeof(is_in_dim_list));
340-
for (const auto& d : dim_list.value()) {
341-
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
342-
is_in_dim_list[non_neg_d] = true;
343-
}
344-
345-
// Compute the starting base index
346-
const size_t base = get_init_index(in, dim_list, out_ix);
347-
348-
apply_on_flat_ix_with_dim_mask_and_base(
349-
fn, in, is_in_dim_list, base, ustart, uend);
400+
ApplyOverDimListPlan plan(in, dim_list, start, end);
401+
plan.execute(fn, out_ix);
350402
}
351403

352404
//

0 commit comments

Comments
 (0)