diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp
index 1822a06f29f..9b6cd8700f9 100644
--- a/kernels/optimized/cpu/op_log_softmax.cpp
+++ b/kernels/optimized/cpu/op_log_softmax.cpp
@@ -14,10 +14,10 @@
 #include
 #include
 
-#include
-#include
+#include
 #include
 #include
+#include
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
 // Tensor rescaling them so that the elements of the n-dimensional output
@@ -51,59 +51,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     inner_size *= input.size(i);
   }
 
-  int64_t dim_stride = inner_size;
-  int64_t outer_stride = dim_size * dim_stride;
-
-  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
-      const IN_T* input_data =
-          input_data_base + outer_idx * outer_stride + inner_idx;
-      OUT_T* output_data =
-          output_data_base + outer_idx * outer_stride + inner_idx;
-
-      // calculate max in softmax dim
-      IN_T max_input = input_data[0];
-      for (auto d = 0; d < dim_size; ++d) {
-        max_input = std::max(max_input, input_data[d * dim_stride]);
-      }
-      // calculate sum and exponential in softmax dim
-      OUT_T temp_sum = 0;
-      using VecOut = at::vec::Vectorized<OUT_T>;
-      using VecIn = at::vec::Vectorized<IN_T>;
-      auto d = 0;
-      static_assert(sizeof(IN_T) == sizeof(OUT_T));
-      static_assert(
-          std::is_same_v<OUT_T, float>,
-          "Below loop actually only supports float.");
-      // It is not correct to vectorize if dim is not contiguous!
-      if (dim_stride == 1) {
-        const VecIn max_input_vec(max_input);
-        for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-          auto index = d * dim_stride;
-          auto in = VecIn::loadu(&input_data[index]);
-          auto out_ = (in - max_input_vec).exp();
-          out_.store(&output_data[index]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-          temp_sum += vaddvq_f32(out_);
-#else
-          temp_sum += at::vec::vec_reduce_all(std::plus(), out_);
-#endif
-        }
-      }
-      for (; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-
-      temp_sum = std::log(temp_sum);
-
-      for (auto dd = 0; dd < dim_size; ++dd) {
-        output_data[dd * dim_stride] =
-            input_data[dd * dim_stride] - max_input - temp_sum;
-      }
-    }
+  if (dim == input.dim() - 1) {
+    at::native::serial_vec_log_softmax_lastdim_range(
+        input_data_base,
+        output_data_base,
+        dim_size,
+        at::native::vec_log_softmax_lastdim_chunk_size<float>(
+            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
+        // TODO: parallelize.
+        0,
+        outer_size);
+  } else {
+    // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
+    // halve it to try and have a better chance of fitting in mobile
+    // chip caches.
+    const auto [chunk_size, num_chunks] =
+        at::native::vec_logsoftmax_chunk_size_and_num_chunks<
+            float,
+            /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
+    at::native::serial_vec_logsoftmax_range(
+        input_data_base,
+        output_data_base,
+        inner_size,
+        chunk_size,
+        num_chunks,
+        dim_size,
+        // TODO: parallelize
+        0,
+        outer_size * num_chunks);
   }
+  return;
 }
 
 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
index 13acdf96d60..f2d471df9fb 100644
--- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
@@ -230,6 +230,7 @@ OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = [
+            "//executorch/extension/threadpool:threadpool",
            "//executorch/kernels/portable/cpu/util:activation_ops_util",
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],
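
For reference, the hand-written loop removed above and the `at::native::serial_vec_log_softmax_*` helpers that replace it both compute the standard numerically stable log-softmax: subtract the per-row maximum before exponentiating so `exp()` cannot overflow, then produce `x - max - log(sum(exp(x - max)))`. Below is a minimal, self-contained scalar sketch of that computation over the last (contiguous) dimension; it is illustrative only (the function name and the `main()` driver are made up here) and does not reflect the vectorized ATen implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative sketch: numerically stable log-softmax over one contiguous
// row of length dim_size:
//   out[d] = in[d] - max(in) - log(sum_d exp(in[d] - max(in)))
void log_softmax_lastdim(const float* in, float* out, int64_t dim_size) {
  // 1) Row max keeps exp() from overflowing for large inputs.
  float max_input = in[0];
  for (int64_t d = 1; d < dim_size; ++d) {
    max_input = std::max(max_input, in[d]);
  }
  // 2) exp(x - max) and its running sum (the removed loop's temp_sum).
  float sum = 0.0f;
  for (int64_t d = 0; d < dim_size; ++d) {
    out[d] = std::exp(in[d] - max_input);
    sum += out[d];
  }
  // 3) Final result: x - max - log(sum).
  const float log_sum = std::log(sum);
  for (int64_t d = 0; d < dim_size; ++d) {
    out[d] = in[d] - max_input - log_sum;
  }
}

int main() {
  const float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4];
  log_softmax_lastdim(in, out, 4);
  for (float v : out) {
    std::printf("%f\n", v);  // exp(out) sums to ~1 across the row
  }
  return 0;
}
```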