| 
14 | 14 | #include <cmath>  | 
15 | 15 | #include <type_traits>  | 
16 | 16 | 
 
  | 
 | 17 | +#include <ATen/cpu/vec/functional.h>  | 
 | 18 | +#include <ATen/cpu/vec/vec.h>  | 
17 | 19 | #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>  | 
18 | 20 | #include <executorch/runtime/kernel/kernel_includes.h>  | 
19 | 21 | 
 
  | 
@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {  | 
66 | 68 |       }  | 
67 | 69 |       // calculate sum and exponential in softmax dim  | 
68 | 70 |       OUT_T temp_sum = 0;  | 
69 |  | -#ifndef __aarch64__  | 
70 |  | -      for (auto d = 0; d < dim_size; ++d) {  | 
71 |  | -        output_data[d * dim_stride] =  | 
72 |  | -            std::exp(input_data[d * dim_stride] - max_input);  | 
73 |  | -        temp_sum += output_data[d * dim_stride];  | 
74 |  | -      }  | 
75 |  | -#else  | 
 | 71 | +      using VecOut = at::vec::Vectorized<OUT_T>;  | 
 | 72 | +      using VecIn = at::vec::Vectorized<IN_T>;  | 
76 | 73 |       auto d = 0;  | 
77 |  | -      for (; d + 4 < dim_size; d += 4) {  | 
 | 74 | +      static_assert(sizeof(IN_T) == sizeof(OUT_T));  | 
 | 75 | +      static_assert(  | 
 | 76 | +          std::is_same_v<OUT_T, float>,  | 
 | 77 | +          "Below loop actually only supports float.");  | 
 | 78 | +      const VecIn max_input_vec(max_input);  | 
 | 79 | +      for (; d + VecOut::size() < dim_size; d += VecOut::size()) {  | 
78 | 80 |         auto index = d * dim_stride;  | 
79 |  | -        float32x4_t in =  | 
80 |  | -            vld1q_f32(static_cast<const float*>(&input_data[index]));  | 
81 |  | -        float32x4_t out_ =  | 
82 |  | -            Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));  | 
83 |  | -        vst1q_f32(static_cast<float*>(&output_data[index]), out_);  | 
 | 81 | +        auto in = VecIn::loadu(&input_data[index]);  | 
 | 82 | +        auto out_ = (in - max_input_vec).exp();  | 
 | 83 | +        out_.store(&output_data[index]);  | 
 | 84 | +#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)  | 
84 | 85 |         temp_sum += vaddvq_f32(out_);  | 
 | 86 | +#else  | 
 | 87 | +        temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);  | 
 | 88 | +#endif  | 
85 | 89 |       }  | 
86 |  | - | 
87 | 90 |       for (; d < dim_size; ++d) {  | 
88 | 91 |         output_data[d * dim_stride] =  | 
89 | 92 |             std::exp(input_data[d * dim_stride] - max_input);  | 
90 | 93 |         temp_sum += output_data[d * dim_stride];  | 
91 | 94 |       }  | 
92 |  | -#endif // __aarch64__  | 
93 | 95 | 
 
  | 
94 | 96 |       temp_sum = std::log(temp_sum);  | 
95 | 97 | 
 
  | 
 | 
0 commit comments