Skip to content

Commit 97fd520

Browse files
authored
Use lambda function in ParallelFor2D. (dmlc#9441)
1 parent 54029a5 commit 97fd520

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/common/threading_utils.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <cstdlib> // for malloc, free
1414
#include <functional> // for function
1515
#include <new> // for bad_alloc
16-
#include <type_traits> // for is_signed, conditional_t
16+
#include <type_traits> // for is_signed, conditional_t, is_integral_v, invoke_result_t
1717
#include <vector> // for vector
1818

1919
#include "xgboost/logging.h"
@@ -87,8 +87,9 @@ class BlockedSpace2d {
8787
// dim1 - size of the first dimension in the space
8888
// getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index
8989
// grain_size - max size of produced blocks
90-
BlockedSpace2d(std::size_t dim1, std::function<std::size_t(std::size_t)> getter_size_dim2,
91-
std::size_t grain_size) {
90+
template <typename Getter>
91+
BlockedSpace2d(std::size_t dim1, Getter&& getter_size_dim2, std::size_t grain_size) {
92+
static_assert(std::is_integral_v<std::invoke_result_t<Getter, std::size_t>>);
9293
for (std::size_t i = 0; i < dim1; ++i) {
9394
std::size_t size = getter_size_dim2(i);
9495
// Each row (second dim) is divided into n_blocks
@@ -137,8 +138,9 @@ class BlockedSpace2d {
137138

138139

139140
// Wrapper to implement nested parallelism with simple omp parallel for
140-
inline void ParallelFor2d(BlockedSpace2d const& space, std::int32_t n_threads,
141-
std::function<void(std::size_t, Range1d)> func) {
141+
template <typename Func>
142+
void ParallelFor2d(const BlockedSpace2d& space, int n_threads, Func&& func) {
143+
static_assert(std::is_void_v<std::invoke_result_t<Func, std::size_t, Range1d>>);
142144
std::size_t n_blocks_in_space = space.Size();
143145
CHECK_GE(n_threads, 1);
144146

0 commit comments

Comments
 (0)