
Commit 37e4213

Update
[ghstack-poisoned]
1 parent 40a1bce commit 37e4213

File tree

1 file changed: +38 -24 lines changed


kernels/portable/cpu/op_argmin.cpp

Lines changed: 38 additions & 24 deletions
@@ -12,6 +12,7 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
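The header added above, thread_parallel_interface.h, is where the executorch::extension::parallel_for call used in the next hunk comes from. As a rough illustration (a hypothetical single-threaded stand-in, not the ExecuTorch implementation), the chunking contract the call site relies on is that the [begin, end) index range is split into chunks of at most grain_size indices and the callback runs once per chunk:

// Hypothetical stand-in for executorch::extension::parallel_for, used only to
// illustrate the chunking behavior implied by the call site in this diff.
// The real implementation may dispatch each chunk to a thread pool.
#include <algorithm>
#include <cstdint>
#include <cstdio>

template <typename Func>
void mock_parallel_for(
    int64_t begin, int64_t end, int64_t grain_size, const Func& f) {
  for (int64_t chunk_begin = begin; chunk_begin < end;
       chunk_begin += grain_size) {
    const int64_t chunk_end = std::min(chunk_begin + grain_size, end);
    f(chunk_begin, chunk_end);  // each chunk could run on its own thread
  }
}

int main() {
  // 10 output elements, grain size 4 -> chunks [0,4), [4,8), [8,10).
  mock_parallel_for(0, 10, 4, [](int64_t b, int64_t e) {
    std::printf("chunk [%lld, %lld)\n", static_cast<long long>(b),
                static_cast<long long>(e));
  });
  return 0;
}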
@@ -47,30 +48,43 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    for (const auto out_ix : c10::irange(out.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            // the below condition as written is equivalent to !isnan(accval) &&
-            // (isnan(v) || v < acc_val). cases:
-            // - if neither acc_val nor v is NaN, !(v >= acc_val) is
-            // trivially equivalent to v < acc_val.
-            // - if acc_val is NaN, the whole thing is trivially false.
-            // - if acc_val is not NaN and v is NaN, then v >= acc_val
-            // - is false because all comparisons involving NaN are
-            // - false, so the result is true. The result is trivially
-            // - true for the above condition that uses isnan(v) as
-            // - well.
-            if (!std::isnan(acc_val) && !(v >= acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      out_data[out_ix] = std::get<1>(acc);
-    }
+    // REVIEW: this is the parallelization strategy ATen uses
+    // specifically when the reduction is along the last dimension and
+    // that dimension is contiguous. Is there any particular reason we
+    // shouldn't just always use this strategy since we aren't
+    // otherwise capable of parallelizing reductions?
+    const auto reduction_size =
+        dim.has_value() ? in.sizes().at(dim.value()) : in.numel();
+    const auto grain_size = std::max(
+        static_cast<int64_t>(1),
+        executorch::extension::internal::GRAIN_SIZE / reduction_size);
+    executorch::extension::parallel_for(
+        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  // the below condition as written is equivalent to
+                  // !isnan(accval) && (isnan(v) || v < acc_val). cases:
+                  // - if neither acc_val nor v is NaN, !(v >= acc_val) is
+                  // trivially equivalent to v < acc_val.
+                  // - if acc_val is NaN, the whole thing is trivially false.
+                  // - if acc_val is not NaN and v is NaN, then v >= acc_val
+                  // - is false because all comparisons involving NaN are
+                  // - false, so the result is true. The result is trivially
+                  // - true for the above condition that uses isnan(v) as
+                  // - well.
+                  if (!std::isnan(acc_val) && !(v >= acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            out_data[out_ix] = std::get<1>(acc);
+          }
+        });
   });
 
   return out;
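For a concrete sense of the grain-size arithmetic in the hunk above: executorch::extension::internal::GRAIN_SIZE is an internal runtime constant, so the 32768 in this sketch is only an assumed placeholder, but it shows how larger reductions get smaller grains so that the work per parallel chunk stays roughly constant.

// Illustration of the grain-size formula from the diff:
//   grain_size = max(1, GRAIN_SIZE / reduction_size)
// The GRAIN_SIZE value here is an assumed placeholder, not the runtime's.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t kAssumedGrainSize = 32768;  // placeholder value
  // e.g. argmin over dim=1 of a [128, 1024] tensor: reduction_size = 1024,
  // out.numel() = 128.
  const int64_t reduction_size = 1024;
  const int64_t grain_size =
      std::max(static_cast<int64_t>(1), kAssumedGrainSize / reduction_size);
  std::printf("grain_size = %lld\n", static_cast<long long>(grain_size));
  // -> 32: each parallel chunk covers up to 32 of the 128 output elements,
  // so each chunk reduces about 32 * 1024 elements, i.e. roughly GRAIN_SIZE.
  return 0;
}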
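The comment inside the reduction lambda argues that !std::isnan(acc_val) && !(v >= acc_val) is the same predicate as !std::isnan(acc_val) && (std::isnan(v) || v < acc_val). A quick standalone check over the NaN/non-NaN combinations (a verification sketch, not part of the kernel) bears that out:

// Standalone check that the two forms of the "should v replace the current
// minimum?" predicate from the diff agree, including the NaN cases.
#include <cassert>
#include <cmath>
#include <limits>

static bool as_written(float v, float acc_val) {
  return !std::isnan(acc_val) && !(v >= acc_val);
}

static bool as_explained(float v, float acc_val) {
  return !std::isnan(acc_val) && (std::isnan(v) || v < acc_val);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float vals[] = {-1.0f, 0.0f, 2.5f, nan};
  for (float v : vals) {
    for (float acc : vals) {
      assert(as_written(v, acc) == as_explained(v, acc));
    }
  }
  // Net effect: a NaN candidate displaces any non-NaN accumulator, but once
  // the accumulator is NaN it is never replaced.
  return 0;
}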
