
Commit 1105543

Author: Github Executorch
Use at::Vectorized in optimized log_softmax
This should allow us to enable this op in OSS, because Vectorized handles any Sleef issues for us as needed. (I considered going straight to sharing the PyTorch core implementation, but we need parallel_for enabled for that, and this improvement is easy enough to make.)

Differential Revision: [D69473208](https://our.internmc.facebook.com/intern/diff/D69473208/)

[ghstack-poisoned]
1 parent 8d96d74 commit 1105543
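
As a rough sketch of the pattern this change adopts (not the committed kernel; the function name `exp_shifted_sum` and the contiguous-data layout are assumptions made for illustration), `at::vec::Vectorized` keeps the exp-and-accumulate loop portable across NEON, AVX2, etc. without hand-written intrinsics or direct Sleef calls:

```cpp
#include <cmath>
#include <functional>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

// Sketch only: compute out[i] = exp(in[i] - max_input) over a contiguous run of
// floats and return the sum of the results. The same source compiles to whatever
// SIMD width Vectorized<float> provides on the target CPU.
float exp_shifted_sum(const float* in, float* out, int64_t n, float max_input) {
  using Vec = at::vec::Vectorized<float>;
  const Vec max_vec(max_input);
  float sum = 0;
  int64_t i = 0;
  // Main loop: Vec::size() lanes per iteration.
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec e = (Vec::loadu(in + i) - max_vec).exp();
    e.store(out + i);
    // Horizontal add of the vector lanes into the scalar accumulator.
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), e);
  }
  // Scalar tail for the leftover elements.
  for (; i < n; ++i) {
    out[i] = std::exp(in[i] - max_input);
    sum += out[i];
  }
  return sum;
}
```

The committed kernel applies the same idea over strided data (`index = d * dim_stride`) and, on aarch64 without SVE, reduces with `vaddvq_f32` instead of `vec_reduce_all`.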

2 files changed: +21, -24 lines

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 17 additions & 15 deletions
@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>

+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
       }
       // calculate sum and exponential in softmax dim
       OUT_T temp_sum = 0;
-#ifndef __aarch64__
-      for (auto d = 0; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-#else
+      using VecOut = at::vec::Vectorized<OUT_T>;
+      using VecIn = at::vec::Vectorized<IN_T>;
       auto d = 0;
-      for (; d + 4 < dim_size; d += 4) {
+      static_assert(sizeof(IN_T) == sizeof(OUT_T));
+      static_assert(
+          std::is_same_v<OUT_T, float>,
+          "Below loop actually only supports float.");
+      const VecIn max_input_vec(max_input);
+      for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
         auto index = d * dim_stride;
-        float32x4_t in =
-            vld1q_f32(static_cast<const float*>(&input_data[index]));
-        float32x4_t out_ =
-            Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-        vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+        auto in = VecIn::loadu(&input_data[index]);
+        auto out_ = (in - max_input_vec).exp();
+        out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
         temp_sum += vaddvq_f32(out_);
+#else
+        temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
       }
-
       for (; d < dim_size; ++d) {
         output_data[d * dim_stride] =
             std::exp(input_data[d * dim_stride] - max_input);
         temp_sum += output_data[d * dim_stride];
       }
-#endif // __aarch64__

       temp_sum = std::log(temp_sum);

kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 9 deletions
@@ -56,15 +56,10 @@ _OPTIMIZED_ATEN_OPS = (
     ),
     op_target(
         name = "op_log_softmax",
-        deps = select({
-            "DEFAULT": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-            ],
-            "ovr_config//cpu:arm64": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-                "fbsource//third-party/sleef:sleef_arm",
-            ],
-        }),
+        deps = [
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            "//executorch/runtime/core/portable_type/c10:aten_headers_for_executorch",
+        ],
     ),
     op_target(
         name = "op_mm",