kernels/optimized/cpu/op_log_softmax.cpp: 83 changes (31 additions, 52 deletions)
@@ -16,8 +16,10 @@

 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
 // Tensor rescaling them so that the elements of the n-dimensional output
@@ -51,59 +53,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     inner_size *= input.size(i);
   }
 
-  int64_t dim_stride = inner_size;
-  int64_t outer_stride = dim_size * dim_stride;
-
-  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
-      const IN_T* input_data =
-          input_data_base + outer_idx * outer_stride + inner_idx;
-      OUT_T* output_data =
-          output_data_base + outer_idx * outer_stride + inner_idx;
-
-      // calculate max in softmax dim
-      IN_T max_input = input_data[0];
-      for (auto d = 0; d < dim_size; ++d) {
-        max_input = std::max(max_input, input_data[d * dim_stride]);
-      }
-      // calculate sum and exponential in softmax dim
-      OUT_T temp_sum = 0;
-      using VecOut = at::vec::Vectorized<OUT_T>;
-      using VecIn = at::vec::Vectorized<IN_T>;
-      auto d = 0;
-      static_assert(sizeof(IN_T) == sizeof(OUT_T));
-      static_assert(
-          std::is_same_v<OUT_T, float>,
-          "Below loop actually only supports float.");
-      // It is not correct to vectorize if dim is not contiguous!
-      if (dim_stride == 1) {
-        const VecIn max_input_vec(max_input);
-        for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-          auto index = d * dim_stride;
-          auto in = VecIn::loadu(&input_data[index]);
-          auto out_ = (in - max_input_vec).exp();
-          out_.store(&output_data[index]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-          temp_sum += vaddvq_f32(out_);
-#else
-          temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
-#endif
-        }
-      }
-      for (; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-
-      temp_sum = std::log(temp_sum);
-
-      for (auto dd = 0; dd < dim_size; ++dd) {
-        output_data[dd * dim_stride] =
-            input_data[dd * dim_stride] - max_input - temp_sum;
-      }
-    }
-  }
+  if (dim == input.dim() - 1) {
+    at::native::serial_vec_log_softmax_lastdim_range(
+        input_data_base,
+        output_data_base,
+        dim_size,
+        at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
+            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
+        // TODO: parallelize.
+        0,
+        outer_size);
+  } else {
+    // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
+    // halve it to try and have a better chance of fitting in mobile
+    // chip caches.
+    const auto [chunk_size, num_chunks] =
+        at::native::vec_logsoftmax_chunk_size_and_num_chunks<
+            float,
+            /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
+    at::native::serial_vec_logsoftmax_range(
+        input_data_base,
+        output_data_base,
+        inner_size,
+        chunk_size,
+        num_chunks,
+        dim_size,
+        // TODO: parallelize
+        0,
+        outer_size * num_chunks);
+  }
   return;
 }
 
 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
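
For reference, the deleted loop nest computed the standard numerically stable three-pass log-softmax along `dim`: find the max, exponentiate the shifted inputs and accumulate their sum, then write `x - max - log(sum)`. A minimal scalar sketch of that computation over one contiguous lane (illustrative only; the hypothetical `log_softmax_1d` below ignores the stride handling and vectorization the removed kernel performed):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Numerically stable log-softmax over n contiguous floats.
// Subtracting the max keeps every exp() argument <= 0, so nothing overflows.
void log_softmax_1d(const float* x, float* out, std::size_t n) {
  float max_x = x[0];
  for (std::size_t i = 1; i < n; ++i) {
    max_x = std::max(max_x, x[i]); // pass 1: running max
  }
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    sum += std::exp(x[i] - max_x); // pass 2: shifted exp-sum
  }
  const float log_sum = std::log(sum);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = x[i] - max_x - log_sum; // pass 3: x - max - log(sum)
  }
}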
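
The replacement dispatches to ATen's vectorized helpers: `serial_vec_log_softmax_lastdim_range` when the softmax dim is innermost (contiguous), and `serial_vec_logsoftmax_range` otherwise, where the non-contiguous path splits `inner_size` into chunks sized so that roughly one BLOCK_SIZE-byte working block (64 KiB here, half of PyTorch's server-oriented default) is processed at a time. A rough sketch of what such BLOCK_SIZE-driven chunking could look like (an assumption for illustration; ATen's actual `vec_logsoftmax_chunk_size_and_num_chunks` may differ):

#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical illustration, not ATen's implementation: pick chunk_size so
// chunk_size * dim_size floats occupy about block_size_bytes, then cover
// inner_size with ceil(inner_size / chunk_size) chunks.
std::pair<int64_t, int64_t> chunk_size_and_num_chunks(
    int64_t inner_size, int64_t dim_size, int64_t block_size_bytes) {
  const int64_t chunk_size = std::max<int64_t>(
      1, block_size_bytes / (dim_size * static_cast<int64_t>(sizeof(float))));
  const int64_t num_chunks = (inner_size + chunk_size - 1) / chunk_size;
  return {chunk_size, num_chunks};
}

The `GRAIN_SIZE` constant used on the last-dim path comes from the new `thread_parallel_interface.h` include, which is why the targets change below adds the `//executorch/extension/threadpool:threadpool` dependency.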
@@ -230,6 +230,7 @@ OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = [
+            "//executorch/extension/threadpool:threadpool",
             "//executorch/kernels/portable/cpu/util:activation_ops_util",
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],