Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions kernels/optimized/cpu/op_log_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
}

if (dim == input.dim() - 1) {
at::native::serial_vec_log_softmax_lastdim_range(
input_data_base,
output_data_base,
dim_size,
at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
// TODO: parallelize.
::executorch::extension::parallel_for(
0,
outer_size);
outer_size,
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
at::native::serial_vec_log_softmax_lastdim_range(
input_data_base,
output_data_base,
dim_size,
at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
executorch::extension::internal::GRAIN_SIZE,
outer_size,
dim_size),
begin,
end);
});
} else {
// BLOCK_SIZE in PyTorch is intended for server CPUs; let's
// halve it to try and have a better chance of fitting in mobile
// chip caches.
const auto [chunk_size, num_chunks] =
const auto [chunk_size_binding, num_chunks_binding] =
at::native::vec_logsoftmax_chunk_size_and_num_chunks<
float,
/*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
at::native::serial_vec_logsoftmax_range(
input_data_base,
output_data_base,
inner_size,
chunk_size,
num_chunks,
dim_size,
// TODO: parallelize
// Work around "capturing a structured binding is not yet supported in
// OpenMP".
const auto chunk_size = chunk_size_binding;
const auto num_chunks = num_chunks_binding;
::executorch::extension::parallel_for(
0,
outer_size * num_chunks);
outer_size * num_chunks,
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
at::native::serial_vec_logsoftmax_range(
input_data_base,
output_data_base,
inner_size,
chunk_size,
num_chunks,
dim_size,
begin,
end);
});
}
return;
}
Expand Down
Loading