Skip to content

Commit 423baca

Browse files
committed
Update
[ghstack-poisoned]
1 parent 22d13da commit 423baca

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ class ApplyOverDimListPlan {
327327
return;
328328
}
329329
dim_list_ = dim_list.value();
330+
if (dim_list_.value().size() == 1) {
331+
mode_ = ExecutionMode::OnlyOneDim;
332+
return;
333+
}
330334
is_in_dim_list_.fill(0);
331335
for (const auto& d : dim_list.value()) {
332336
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
@@ -347,6 +351,16 @@ class ApplyOverDimListPlan {
347351
apply_on_flat_ix_with_stride_and_base(
348352
fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
349353
return;
354+
case ExecutionMode::OnlyOneDim:
355+
apply_on_flat_and_dim_ix_with_stride_and_base(
356+
[&](const auto in_ix, const auto dim_ix) {
357+
fn(in_ix);
358+
},
359+
in_.strides()[ET_NORMALIZE_IX(dim_list_.value()[0], in_.dim())],
360+
get_init_index(in_, dim_list_.value(), out_ix),
361+
ustart_,
362+
uend_);
363+
return;
350364
case ExecutionMode::NormalDimMask:
351365
apply_on_flat_ix_with_dim_mask_and_base(
352366
fn,
@@ -379,6 +393,8 @@ class ApplyOverDimListPlan {
379393
// Iterate over the entire tensor with
380394
// apply_on_flat_ix_with_stride_and_base.
381395
NoDimMaskOrZeroDimension,
396+
// dim_list has size 1, iterate with apply_on_flat_and_dim_ix_with_stride_and_base
397+
OnlyOneDim,
382398
// General mode, iterate with
383399
// apply_on_flat_ix_with_dim_mask_and_base.
384400
NormalDimMask

0 commit comments

Comments
 (0)