@@ -54,33 +54,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
5454 }
5555
5656 if (dim == input.dim () - 1 ) {
57- at::native::serial_vec_log_softmax_lastdim_range (
58- input_data_base,
59- output_data_base,
60- dim_size,
61- at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
62- executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
63- // TODO: parallelize.
57+ ::executorch::extension::parallel_for (
6458 0 ,
65- outer_size);
59+ outer_size,
60+ ::executorch::extension::internal::GRAIN_SIZE,
61+ [&](const auto begin, const auto end) {
62+ at::native::serial_vec_log_softmax_lastdim_range (
63+ input_data_base,
64+ output_data_base,
65+ dim_size,
66+ at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
67+ executorch::extension::internal::GRAIN_SIZE,
68+ outer_size,
69+ dim_size),
70+ begin,
71+ end);
72+ });
6673 } else {
6774 // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
6875 // halve it to try and have a better chance of fitting in mobile
6976 // chip caches.
70- const auto [chunk_size, num_chunks ] =
77+ const auto [chunk_size_binding, num_chunks_binding ] =
7178 at::native::vec_logsoftmax_chunk_size_and_num_chunks<
7279 float ,
7380 /* BLOCK_SIZE=*/ 64 * 1024 >(inner_size, dim_size);
74- at::native::serial_vec_logsoftmax_range (
75- input_data_base,
76- output_data_base,
77- inner_size,
78- chunk_size,
79- num_chunks,
80- dim_size,
81- // TODO: parallelize
81+ // Work around "capturing a structured binding is not yet supported in
82+ // OpenMP".
83+ const auto chunk_size = chunk_size_binding;
84+ const auto num_chunks = num_chunks_binding;
85+ ::executorch::extension::parallel_for (
8286 0 ,
83- outer_size * num_chunks);
87+ outer_size * num_chunks,
88+ ::executorch::extension::internal::GRAIN_SIZE,
89+ [&](const auto begin, const auto end) {
90+ at::native::serial_vec_logsoftmax_range (
91+ input_data_base,
92+ output_data_base,
93+ inner_size,
94+ chunk_size,
95+ num_chunks,
96+ dim_size,
97+ begin,
98+ end);
99+ });
84100 }
85101 return ;
86102}
0 commit comments