diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp index 9b6cd8700f9..ec05c254273 100644 --- a/kernels/optimized/cpu/op_log_softmax.cpp +++ b/kernels/optimized/cpu/op_log_softmax.cpp @@ -52,33 +52,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) { } if (dim == input.dim() - 1) { - at::native::serial_vec_log_softmax_lastdim_range( - input_data_base, - output_data_base, - dim_size, - at::native::vec_log_softmax_lastdim_chunk_size( - executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size), - // TODO: parallelize. + ::executorch::extension::parallel_for( 0, - outer_size); + outer_size, + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + at::native::serial_vec_log_softmax_lastdim_range( + input_data_base, + output_data_base, + dim_size, + at::native::vec_log_softmax_lastdim_chunk_size( + executorch::extension::internal::GRAIN_SIZE, + outer_size, + dim_size), + begin, + end); + }); } else { // BLOCK_SIZE in PyTorch is intended for server CPUs; let's // halve it to try and have a better chance of fitting in mobile // chip caches. - const auto [chunk_size, num_chunks] = + const auto [chunk_size_binding, num_chunks_binding] = at::native::vec_logsoftmax_chunk_size_and_num_chunks< float, /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size); - at::native::serial_vec_logsoftmax_range( - input_data_base, - output_data_base, - inner_size, - chunk_size, - num_chunks, - dim_size, - // TODO: parallelize + // Work around "capturing a structured binding is not yet supported in + // OpenMP". + const auto chunk_size = chunk_size_binding; + const auto num_chunks = num_chunks_binding; + ::executorch::extension::parallel_for( 0, - outer_size * num_chunks); + outer_size * num_chunks, + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + at::native::serial_vec_logsoftmax_range( + input_data_base, + output_data_base, + inner_size, + chunk_size, + num_chunks, + dim_size, + begin, + end); + }); } return; }