|
@@ -16,8 +16,10 @@
 
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
 // Tensor rescaling them so that the elements of the n-dimensional output
@@ -51,59 +53,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     inner_size *= input.size(i);
   }
 
-  int64_t dim_stride = inner_size;
-  int64_t outer_stride = dim_size * dim_stride;
-
-  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
-      const IN_T* input_data =
-          input_data_base + outer_idx * outer_stride + inner_idx;
-      OUT_T* output_data =
-          output_data_base + outer_idx * outer_stride + inner_idx;
-
-      // calculate max in softmax dim
-      IN_T max_input = input_data[0];
-      for (auto d = 0; d < dim_size; ++d) {
-        max_input = std::max(max_input, input_data[d * dim_stride]);
-      }
-      // calculate sum and exponential in softmax dim
-      OUT_T temp_sum = 0;
-      using VecOut = at::vec::Vectorized<OUT_T>;
-      using VecIn = at::vec::Vectorized<IN_T>;
-      auto d = 0;
-      static_assert(sizeof(IN_T) == sizeof(OUT_T));
-      static_assert(
-          std::is_same_v<OUT_T, float>,
-          "Below loop actually only supports float.");
-      // It is not correct to vectorize if dim is not contiguous!
-      if (dim_stride == 1) {
-        const VecIn max_input_vec(max_input);
-        for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-          auto index = d * dim_stride;
-          auto in = VecIn::loadu(&input_data[index]);
-          auto out_ = (in - max_input_vec).exp();
-          out_.store(&output_data[index]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-          temp_sum += vaddvq_f32(out_);
-#else
-          temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
-#endif
-        }
-      }
-      for (; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-
-      temp_sum = std::log(temp_sum);
-
-      for (auto dd = 0; dd < dim_size; ++dd) {
-        output_data[dd * dim_stride] =
-            input_data[dd * dim_stride] - max_input - temp_sum;
-      }
-    }
+  if (dim == input.dim() - 1) {
+    at::native::serial_vec_log_softmax_lastdim_range(
+        input_data_base,
+        output_data_base,
+        dim_size,
+        at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
+            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
+        // TODO: parallelize.
+        0,
+        outer_size);
+  } else {
+    // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
+    // halve it to try and have a better chance of fitting in mobile
+    // chip caches.
+    const auto [chunk_size, num_chunks] =
+        at::native::vec_logsoftmax_chunk_size_and_num_chunks<
+            float,
+            /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
+    at::native::serial_vec_logsoftmax_range(
+        input_data_base,
+        output_data_base,
+        inner_size,
+        chunk_size,
+        num_chunks,
+        dim_size,
+        // TODO: parallelize
+        0,
+        outer_size * num_chunks);
   }
+  return;
 }
 
 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
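For reference, both the removed scalar loop and the ATen `serial_vec_log_softmax_lastdim_range` / `serial_vec_logsoftmax_range` helpers that replace it compute the same numerically stable log-softmax per slice: take the max along `dim`, exponentiate the shifted values, sum them, and write `x - max - log(sum)`. A minimal standalone sketch of that per-slice computation (plain C++ with illustrative names; not the ExecuTorch or ATen API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference scalar log-softmax over one slice of length `dim_size`, where
// consecutive elements along the softmax dim are `dim_stride` apart
// (dim_stride == 1 for the last-dim case, inner_size otherwise).
void log_softmax_slice(
    const float* in, float* out, int64_t dim_size, int64_t dim_stride) {
  // Max along the softmax dim, for numerical stability.
  float max_input = in[0];
  for (int64_t d = 0; d < dim_size; ++d) {
    max_input = std::max(max_input, in[d * dim_stride]);
  }
  // Sum of exp(x - max).
  float sum = 0.0f;
  for (int64_t d = 0; d < dim_size; ++d) {
    sum += std::exp(in[d * dim_stride] - max_input);
  }
  const float log_sum = std::log(sum);
  // log_softmax(x) = x - max - log(sum).
  for (int64_t d = 0; d < dim_size; ++d) {
    out[d * dim_stride] = in[d * dim_stride] - max_input - log_sum;
  }
}

int main() {
  // One contiguous slice (dim_stride == 1) of length 4.
  std::vector<float> in = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> out(in.size());
  log_softmax_slice(in.data(), out.data(), static_cast<int64_t>(in.size()), 1);
  for (float v : out) {
    std::printf("%f\n", v);  // exp() of these values sums to 1
  }
}
```

If I'm reading the new branches right, the last-dim call vectorizes the contiguous (`dim_stride == 1`) case, while the chunked call handles the strided case by splitting `inner_size` into `num_chunks` blocks of `chunk_size`, so that each working block stays within the halved BLOCK_SIZE cache budget mentioned in the comment.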
|