From 3caf461c22de492b37bd8ead959dd5ee9cd5e658 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 18 Feb 2025 14:07:12 -0800 Subject: [PATCH] Use at::Vectorized in optimized log_softmax Pull Request resolved: https://github.com/pytorch/executorch/pull/8382 This should allow us to enable this op in OSS, because Vectorized handles any Sleef issues for us as needed. (I considered going straight to sharing the PyTorch core implementation, but we need parallel_for enabled for that and this improvement is easy enough to make.) Differential Revision: [D69473208](https://our.internmc.facebook.com/intern/diff/D69473208/) ghstack-source-id: 267044107 --- kernels/optimized/cpu/op_log_softmax.cpp | 32 +++++++++++++----------- kernels/optimized/cpu/targets.bzl | 13 +++------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp index c3f090a6dfe..1d2467bca5f 100644 --- a/kernels/optimized/cpu/op_log_softmax.cpp +++ b/kernels/optimized/cpu/op_log_softmax.cpp @@ -14,6 +14,8 @@ #include #include +#include +#include #include #include @@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) { } // calculate sum and exponential in softmax dim OUT_T temp_sum = 0; -#ifndef __aarch64__ - for (auto d = 0; d < dim_size; ++d) { - output_data[d * dim_stride] = - std::exp(input_data[d * dim_stride] - max_input); - temp_sum += output_data[d * dim_stride]; - } -#else + using VecOut = at::vec::Vectorized; + using VecIn = at::vec::Vectorized; auto d = 0; - for (; d + 4 < dim_size; d += 4) { + static_assert(sizeof(IN_T) == sizeof(OUT_T)); + static_assert( + std::is_same_v, + "Below loop actually only supports float."); + const VecIn max_input_vec(max_input); + for (; d + VecOut::size() < dim_size; d += VecOut::size()) { auto index = d * dim_stride; - float32x4_t in = - vld1q_f32(static_cast(&input_data[index])); - float32x4_t out_ = - Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input))); - vst1q_f32(static_cast(&output_data[index]), out_); + auto in = VecIn::loadu(&input_data[index]); + auto out_ = (in - max_input_vec).exp(); + out_.store(&output_data[index]); +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) temp_sum += vaddvq_f32(out_); +#else + temp_sum += at::vec::vec_reduce_all(std::plus(), out_); +#endif } - for (; d < dim_size; ++d) { output_data[d * dim_stride] = std::exp(input_data[d * dim_stride] - max_input); temp_sum += output_data[d * dim_stride]; } -#endif // __aarch64__ temp_sum = std::log(temp_sum); diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 1c62b683b8f..9fb1f30fa9f 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -56,15 +56,10 @@ _OPTIMIZED_ATEN_OPS = ( ), op_target( name = "op_log_softmax", - deps = select({ - "DEFAULT": [ - "//executorch/kernels/portable/cpu/util:activation_ops_util", - ], - "ovr_config//cpu:arm64": [ - "//executorch/kernels/portable/cpu/util:activation_ops_util", - "fbsource//third-party/sleef:sleef_arm", - ], - }), + deps = [ + "//executorch/kernels/portable/cpu/util:activation_ops_util", + "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", + ], ), op_target( name = "op_mm",