@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>

+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   }
   // calculate sum and exponential in softmax dim
   OUT_T temp_sum = 0;
-#ifndef __aarch64__
-  for (auto d = 0; d < dim_size; ++d) {
-    output_data[d * dim_stride] =
-        std::exp(input_data[d * dim_stride] - max_input);
-    temp_sum += output_data[d * dim_stride];
-  }
-#else
+  using VecOut = at::vec::Vectorized<OUT_T>;
+  using VecIn = at::vec::Vectorized<IN_T>;
   auto d = 0;
-  for (; d + 4 < dim_size; d += 4) {
+  static_assert(sizeof(IN_T) == sizeof(OUT_T));
+  static_assert(
+      std::is_same_v<OUT_T, float>,
+      "Below loop actually only supports float.");
+  const VecIn max_input_vec(max_input);
+  for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
     auto index = d * dim_stride;
-    float32x4_t in =
-        vld1q_f32(static_cast<const float*>(&input_data[index]));
-    float32x4_t out_ =
-        Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-    vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+    auto in = VecIn::loadu(&input_data[index]);
+    auto out_ = (in - max_input_vec).exp();
+    out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
     temp_sum += vaddvq_f32(out_);
+#else
+    temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
   }
-
   for (; d < dim_size; ++d) {
     output_data[d * dim_stride] =
         std::exp(input_data[d * dim_stride] - max_input);
     temp_sum += output_data[d * dim_stride];
   }
-#endif // __aarch64__

   temp_sum = std::log(temp_sum);

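For reference, here is a minimal standalone sketch of the load/subtract/exp/store/reduce pattern the hunk above switches to, assuming a translation unit that can include ATen's vectorization headers; exp_sum is a hypothetical helper for illustration, not part of the kernel:

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

#include <cmath>
#include <cstdint>
#include <functional>

// Sums exp(x[i] - max_x) over a contiguous float buffer. The main loop
// processes Vec::size() lanes at a time; a scalar loop handles the tail.
float exp_sum(const float* x, int64_t n, float max_x) {
  using Vec = at::vec::Vectorized<float>;
  const Vec max_vec(max_x);  // broadcast max_x to every lane
  float sum = 0;
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    auto v = (Vec::loadu(x + i) - max_vec).exp();
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), v);
  }
  for (; i < n; ++i) {
    sum += std::exp(x[i] - max_x);
  }
  return sum;
}

The same Vectorized<float> code lowers to NEON on aarch64 and to the matching x86 ISA elsewhere (with a scalar fallback), which is why the kernel no longer needs the explicit NEON intrinsics, the direct Sleef call, or the __aarch64__-only code path.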
|