Commit ecb1512

Use shared log_softmax kernels from PyTorch
Previous diff in stack (D76830114 / pytorch/pytorch#156243) extracted them to shared headers; now we can use them. Will send another PR to parallelize.

Differential Revision: [D76830115](https://our.internmc.facebook.com/intern/diff/D76830115/)

[ghstack-poisoned]
1 parent cf0bfd2 commit ecb1512

2 files changed: +32 -52 lines

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 31 additions & 52 deletions
@@ -16,8 +16,10 @@
 
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
 // Tensor rescaling them so that the elements of the n-dimensional output
@@ -51,59 +53,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     inner_size *= input.size(i);
   }
 
-  int64_t dim_stride = inner_size;
-  int64_t outer_stride = dim_size * dim_stride;
-
-  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
-      const IN_T* input_data =
-          input_data_base + outer_idx * outer_stride + inner_idx;
-      OUT_T* output_data =
-          output_data_base + outer_idx * outer_stride + inner_idx;
-
-      // calculate max in softmax dim
-      IN_T max_input = input_data[0];
-      for (auto d = 0; d < dim_size; ++d) {
-        max_input = std::max(max_input, input_data[d * dim_stride]);
-      }
-      // calculate sum and exponential in softmax dim
-      OUT_T temp_sum = 0;
-      using VecOut = at::vec::Vectorized<OUT_T>;
-      using VecIn = at::vec::Vectorized<IN_T>;
-      auto d = 0;
-      static_assert(sizeof(IN_T) == sizeof(OUT_T));
-      static_assert(
-          std::is_same_v<OUT_T, float>,
-          "Below loop actually only supports float.");
-      // It is not correct to vectorize if dim is not contiguous!
-      if (dim_stride == 1) {
-        const VecIn max_input_vec(max_input);
-        for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-          auto index = d * dim_stride;
-          auto in = VecIn::loadu(&input_data[index]);
-          auto out_ = (in - max_input_vec).exp();
-          out_.store(&output_data[index]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-          temp_sum += vaddvq_f32(out_);
-#else
-          temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
-#endif
-        }
-      }
-      for (; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-
-      temp_sum = std::log(temp_sum);
-
-      for (auto dd = 0; dd < dim_size; ++dd) {
-        output_data[dd * dim_stride] =
-            input_data[dd * dim_stride] - max_input - temp_sum;
-      }
-    }
+  if (dim == input.dim() - 1) {
+    at::native::serial_vec_log_softmax_lastdim_range(
+        input_data_base,
+        output_data_base,
+        dim_size,
+        at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
+            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
+        // TODO: parallelize.
+        0,
+        outer_size);
+  } else {
+    // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
+    // halve it to try and have a better chance of fitting in mobile
+    // chip caches.
+    const auto [chunk_size, num_chunks] =
+        at::native::vec_logsoftmax_chunk_size_and_num_chunks<
+            float,
+            /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
+    at::native::serial_vec_logsoftmax_range(
+        input_data_base,
+        output_data_base,
+        inner_size,
+        chunk_size,
+        num_chunks,
+        dim_size,
+        // TODO: parallelize
+        0,
+        outer_size * num_chunks);
   }
+  return;
 }
 
 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
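For reference, the removed hand-rolled loop and the shared PyTorch kernels it now delegates to compute the same numerically stable log-softmax over a tensor viewed as [outer_size, dim_size, inner_size]. Below is a minimal scalar sketch of that math (float-only, no vectorization or chunking; the free function name is made up for illustration and is not part of this commit):

// Minimal scalar reference for log_softmax along the softmax dim of a tensor
// viewed as [outer_size, dim_size, inner_size].
#include <algorithm>
#include <cmath>
#include <cstdint>

void reference_log_softmax(
    const float* in,
    float* out,
    int64_t outer_size,
    int64_t dim_size,
    int64_t inner_size) {
  const int64_t dim_stride = inner_size;
  const int64_t outer_stride = dim_size * dim_stride;
  for (int64_t o = 0; o < outer_size; ++o) {
    for (int64_t i = 0; i < inner_size; ++i) {
      const float* x = in + o * outer_stride + i;
      float* y = out + o * outer_stride + i;
      // Subtract the max before exponentiating for numerical stability.
      float max_x = x[0];
      for (int64_t d = 1; d < dim_size; ++d) {
        max_x = std::max(max_x, x[d * dim_stride]);
      }
      float sum = 0.0f;
      for (int64_t d = 0; d < dim_size; ++d) {
        sum += std::exp(x[d * dim_stride] - max_x);
      }
      const float log_sum = std::log(sum);
      // log_softmax(x)[d] = x[d] - max_x - log(sum_k exp(x[k] - max_x))
      for (int64_t d = 0; d < dim_size; ++d) {
        y[d * dim_stride] = x[d * dim_stride] - max_x - log_sum;
      }
    }
  }
}

The shared kernels split this work into two paths, which is why the new code branches on dim == input.dim() - 1: only the last-dimension case has unit stride along the softmax dim (the old code's dim_stride == 1 check), so only it can be vectorized directly along that dim.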

shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = [
+            "//executorch/extension/threadpool:threadpool",
            "//executorch/kernels/portable/cpu/util:activation_ops_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
        ],
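The thread_parallel_interface.h include and the threadpool dep land in this commit, but both shared-kernel calls still cover the full [0, outer_size) range serially (hence the // TODO: parallelize comments and the "Will send another PR to parallelize" note). A rough sketch of what the promised follow-up might look like for the last-dim branch, assuming ExecuTorch's executorch::extension::parallel_for(begin, end, grain_size, fn) helper; the exact signature, grain size, and range split here are assumptions, not part of this commit:

// Hypothetical parallelization of the last-dim branch inside
// log_softmax_kernel (names like input_data_base, output_data_base, dim_size,
// and outer_size come from the function shown in the diff above).
const auto chunk_size = at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
    executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size);
executorch::extension::parallel_for(
    0,
    outer_size,
    executorch::extension::internal::GRAIN_SIZE,
    [&](int64_t begin, int64_t end) {
      // Each worker processes its own [begin, end) slice of the outer dim.
      at::native::serial_vec_log_softmax_lastdim_range(
          input_data_base, output_data_base, dim_size, chunk_size, begin, end);
    });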
