@@ -52,33 +52,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
5252 }
5353
5454 if (dim == input.dim () - 1 ) {
55- at::native::serial_vec_log_softmax_lastdim_range (
56- input_data_base,
57- output_data_base,
58- dim_size,
59- at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
60- executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
61- // TODO: parallelize.
55+ ::executorch::extension::parallel_for (
6256 0 ,
63- outer_size);
57+ outer_size,
58+ ::executorch::extension::internal::GRAIN_SIZE,
59+ [&](const auto begin, const auto end) {
60+ at::native::serial_vec_log_softmax_lastdim_range (
61+ input_data_base,
62+ output_data_base,
63+ dim_size,
64+ at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
65+ executorch::extension::internal::GRAIN_SIZE,
66+ outer_size,
67+ dim_size),
68+ begin,
69+ end);
70+ });
6471 } else {
6572 // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
6673 // halve it to try and have a better chance of fitting in mobile
6774 // chip caches.
68- const auto [chunk_size, num_chunks ] =
75+ const auto [chunk_size_binding, num_chunks_binding ] =
6976 at::native::vec_logsoftmax_chunk_size_and_num_chunks<
7077 float ,
7178 /* BLOCK_SIZE=*/ 64 * 1024 >(inner_size, dim_size);
72- at::native::serial_vec_logsoftmax_range (
73- input_data_base,
74- output_data_base,
75- inner_size,
76- chunk_size,
77- num_chunks,
78- dim_size,
79- // TODO: parallelize
79+ // Work around "capturing a structured binding is not yet supported in
80+ // OpenMP".
81+ const auto chunk_size = chunk_size_binding;
82+ const auto num_chunks = num_chunks_binding;
83+ ::executorch::extension::parallel_for (
8084 0 ,
81- outer_size * num_chunks);
85+ outer_size * num_chunks,
86+ ::executorch::extension::internal::GRAIN_SIZE,
87+ [&](const auto begin, const auto end) {
88+ at::native::serial_vec_logsoftmax_range (
89+ input_data_base,
90+ output_data_base,
91+ inner_size,
92+ chunk_size,
93+ num_chunks,
94+ dim_size,
95+ begin,
96+ end);
97+ });
8298 }
8399 return ;
84100}
0 commit comments