diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp index c3f090a6dfe..1d2467bca5f 100644 --- a/kernels/optimized/cpu/op_log_softmax.cpp +++ b/kernels/optimized/cpu/op_log_softmax.cpp @@ -14,6 +14,8 @@ #include #include +#include +#include #include #include @@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) { } // calculate sum and exponential in softmax dim OUT_T temp_sum = 0; -#ifndef __aarch64__ - for (auto d = 0; d < dim_size; ++d) { - output_data[d * dim_stride] = - std::exp(input_data[d * dim_stride] - max_input); - temp_sum += output_data[d * dim_stride]; - } -#else + using VecOut = at::vec::Vectorized; + using VecIn = at::vec::Vectorized; auto d = 0; - for (; d + 4 < dim_size; d += 4) { + static_assert(sizeof(IN_T) == sizeof(OUT_T)); + static_assert( + std::is_same_v, + "Below loop actually only supports float."); + const VecIn max_input_vec(max_input); + for (; d + VecOut::size() < dim_size; d += VecOut::size()) { auto index = d * dim_stride; - float32x4_t in = - vld1q_f32(static_cast(&input_data[index])); - float32x4_t out_ = - Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input))); - vst1q_f32(static_cast(&output_data[index]), out_); + auto in = VecIn::loadu(&input_data[index]); + auto out_ = (in - max_input_vec).exp(); + out_.store(&output_data[index]); +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) temp_sum += vaddvq_f32(out_); +#else + temp_sum += at::vec::vec_reduce_all(std::plus(), out_); +#endif } - for (; d < dim_size; ++d) { output_data[d * dim_stride] = std::exp(input_data[d * dim_stride] - max_input); temp_sum += output_data[d * dim_stride]; } -#endif // __aarch64__ temp_sum = std::log(temp_sum); diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 1c62b683b8f..9fb1f30fa9f 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -56,15 +56,10 @@ _OPTIMIZED_ATEN_OPS = ( ), op_target( name = "op_log_softmax", - deps = select({ - "DEFAULT": [ - "//executorch/kernels/portable/cpu/util:activation_ops_util", - ], - "ovr_config//cpu:arm64": [ - "//executorch/kernels/portable/cpu/util:activation_ops_util", - "fbsource//third-party/sleef:sleef_arm", - ], - }), + deps = [ + "//executorch/kernels/portable/cpu/util:activation_ops_util", + "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", + ], ), op_target( name = "op_mm",