
Commit 8c5304f

Use at::Vectorized in optimized log_softmax
Differential Revision: D69473208
Pull Request resolved: #8382
1 parent 680b2de

2 files changed: +21, -24 lines


kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 17 additions & 15 deletions
@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>
 
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   }
   // calculate sum and exponential in softmax dim
   OUT_T temp_sum = 0;
-#ifndef __aarch64__
-  for (auto d = 0; d < dim_size; ++d) {
-    output_data[d * dim_stride] =
-        std::exp(input_data[d * dim_stride] - max_input);
-    temp_sum += output_data[d * dim_stride];
-  }
-#else
+  using VecOut = at::vec::Vectorized<OUT_T>;
+  using VecIn = at::vec::Vectorized<IN_T>;
   auto d = 0;
-  for (; d + 4 < dim_size; d += 4) {
+  static_assert(sizeof(IN_T) == sizeof(OUT_T));
+  static_assert(
+      std::is_same_v<OUT_T, float>,
+      "Below loop actually only supports float.");
+  const VecIn max_input_vec(max_input);
+  for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
     auto index = d * dim_stride;
-    float32x4_t in =
-        vld1q_f32(static_cast<const float*>(&input_data[index]));
-    float32x4_t out_ =
-        Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-    vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+    auto in = VecIn::loadu(&input_data[index]);
+    auto out_ = (in - max_input_vec).exp();
+    out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
     temp_sum += vaddvq_f32(out_);
+#else
+    temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
   }
-
   for (; d < dim_size; ++d) {
     output_data[d * dim_stride] =
        std::exp(input_data[d * dim_stride] - max_input);
    temp_sum += output_data[d * dim_stride];
  }
-#endif // __aarch64__
 
   temp_sum = std::log(temp_sum);
 

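For reference, here is a minimal standalone sketch of the pattern the new kernel loop uses: at::vec::Vectorized<float> provides the wide loads, the vectorized exp, and the horizontal lane reduction, with a scalar tail loop covering any leftover elements. The helper name exp_shifted_sum and its free-function framing are illustrative only, not part of this commit.

#include <cmath>
#include <cstdint>
#include <functional>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

// Hypothetical helper (not the ExecuTorch kernel): writes
// exp(in[i] - max_input) into out and returns the running sum,
// mirroring the vectorized main loop + scalar tail of the new code.
float exp_shifted_sum(const float* in, float* out, int64_t n, float max_input) {
  using Vec = at::vec::Vectorized<float>;
  const Vec max_vec(max_input);
  float sum = 0.0f;
  int64_t d = 0;
  // Main loop: one Vec::size()-wide chunk of floats per iteration.
  for (; d + Vec::size() <= n; d += Vec::size()) {
    auto v = (Vec::loadu(in + d) - max_vec).exp();
    v.store(out + d);
    // Horizontal add across the vector's lanes.
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), v);
  }
  // Scalar tail for the remaining n % Vec::size() elements.
  for (; d < n; ++d) {
    out[d] = std::exp(in[d] - max_input);
    sum += out[d];
  }
  return sum;
}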
kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 9 deletions
@@ -56,15 +56,10 @@ _OPTIMIZED_ATEN_OPS = (
     ),
     op_target(
         name = "op_log_softmax",
-        deps = select({
-            "DEFAULT": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-            ],
-            "ovr_config//cpu:arm64": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-                "fbsource//third-party/sleef:sleef_arm",
-            ],
-        }),
+        deps = [
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     ),
     op_target(
         name = "op_mm",

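The deps simplification works because at::Vectorized picks its backing implementation at compile time (NEON on arm64, AVX2/AVX-512 on x86, a generic scalar fallback elsewhere), so the arm64-only Sleef library and the per-CPU select() are no longer needed; one unconditional dependency on the ATen vec headers covers every target. A hypothetical probe, not part of this commit, illustrating that the same source builds anywhere:

#include <cstdio>

#include <ATen/cpu/vec/vec.h>

int main() {
  using Vec = at::vec::Vectorized<float>;
  // The lane count is fixed at compile time by the target ISA:
  // e.g. 4 on NEON, 8 on AVX2; the generic build emulates a fixed
  // width with plain scalar code, so this compiles on any CPU.
  std::printf("Vectorized<float> lanes: %d\n", static_cast<int>(Vec::size()));
  return 0;
}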