Skip to content

Commit a04da54

Browse files
committed
Parallelize optimized op_log_softmax
Straightforward application of parallel_for. Differential Revision: [D76831122](https://our.internmc.facebook.com/intern/diff/D76831122/) [ghstack-poisoned]
1 parent ecb1512 commit a04da54

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
5454
}
5555

5656
if (dim == input.dim() - 1) {
57-
at::native::serial_vec_log_softmax_lastdim_range(
58-
input_data_base,
59-
output_data_base,
60-
dim_size,
61-
at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
62-
executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
63-
// TODO: parallelize.
57+
::executorch::extension::parallel_for(
6458
0,
65-
outer_size);
59+
outer_size,
60+
::executorch::extension::internal::GRAIN_SIZE,
61+
[&](const auto begin, const auto end) {
62+
at::native::serial_vec_log_softmax_lastdim_range(
63+
input_data_base,
64+
output_data_base,
65+
dim_size,
66+
at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
67+
executorch::extension::internal::GRAIN_SIZE,
68+
outer_size,
69+
dim_size),
70+
begin,
71+
end);
72+
});
6673
} else {
6774
// BLOCK_SIZE in PyTorch is intended for server CPUs; let's
6875
// halve it to try and have a better chance of fitting in mobile
6976
// chip caches.
70-
const auto [chunk_size, num_chunks] =
77+
const auto [chunk_size_binding, num_chunks_binding] =
7178
at::native::vec_logsoftmax_chunk_size_and_num_chunks<
7279
float,
7380
/*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
74-
at::native::serial_vec_logsoftmax_range(
75-
input_data_base,
76-
output_data_base,
77-
inner_size,
78-
chunk_size,
79-
num_chunks,
80-
dim_size,
81-
// TODO: parallelize
81+
// Work around "capturing a structured binding is not yet supported in
82+
// OpenMP".
83+
const auto chunk_size = chunk_size_binding;
84+
const auto num_chunks = num_chunks_binding;
85+
::executorch::extension::parallel_for(
8286
0,
83-
outer_size * num_chunks);
87+
outer_size * num_chunks,
88+
::executorch::extension::internal::GRAIN_SIZE,
89+
[&](const auto begin, const auto end) {
90+
at::native::serial_vec_logsoftmax_range(
91+
input_data_base,
92+
output_data_base,
93+
inner_size,
94+
chunk_size,
95+
num_chunks,
96+
dim_size,
97+
begin,
98+
end);
99+
});
84100
}
85101
return;
86102
}

0 commit comments

Comments
 (0)