Skip to content

Commit 1c61752

Browse files
authored
Limit the maximum number of threads. (dmlc#10872) (dmlc#10904)
1 parent 5556bc3 commit 1c61752

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

src/common/threading_utils.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*/
44
#include "threading_utils.h"
55

6-
#include <algorithm> // for max
6+
#include <algorithm> // for max, min
77
#include <exception> // for exception
88
#include <filesystem> // for path, exists
99
#include <fstream> // for ifstream
@@ -99,17 +99,18 @@ std::int32_t GetCfsCPUCount() noexcept {
9999
return -1;
100100
}
101101

102-
std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
102+
std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true) {
103103
// Don't use parallel if we are in a parallel region.
104104
if (omp_in_parallel()) {
105105
return 1;
106106
}
107+
// Honor the openmp thread limit, which can be set via environment variable.
108+
auto max_n_threads = std::min({omp_get_num_procs(), omp_get_max_threads(), OmpGetThreadLimit()});
107109
// If -1 or 0 is specified by the user, we default to maximum number of threads.
108110
if (n_threads <= 0) {
109-
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
111+
n_threads = max_n_threads;
110112
}
111-
// Honor the openmp thread limit, which can be set via environment variable.
112-
n_threads = std::min(n_threads, OmpGetThreadLimit());
113+
n_threads = std::min(n_threads, max_n_threads);
113114
n_threads = std::max(n_threads, 1);
114115
return n_threads;
115116
}

src/common/threading_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ inline std::int32_t OmpGetThreadLimit() {
257257
std::int32_t GetCfsCPUCount() noexcept;
258258

259259
/**
260-
* \brief Get the number of available threads based on n_threads specified by users.
260+
* @brief Get the number of available threads based on n_threads specified by users.
261261
*/
262-
std::int32_t OmpGetNumThreads(std::int32_t n_threads);
262+
std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true);
263263

264264
/*!
265265
* \brief A C-style array with in-stack allocation. As long as the array is smaller than

tests/cpp/common/test_threading_utils.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
/**
2-
* Copyright 2019-2023 by XGBoost Contributors
2+
* Copyright 2019-2024, XGBoost Contributors
33
*/
44
#include <gtest/gtest.h>
55

66
#include <cstddef> // std::size_t
7+
#include <thread> // for std::thread
78

89
#include "../../../src/common/threading_utils.h" // BlockedSpace2d,ParallelFor2d,ParallelFor
910
#include "dmlc/omp.h" // omp_in_parallel
1011
#include "xgboost/context.h" // Context
1112

12-
namespace xgboost {
13-
namespace common {
14-
13+
namespace xgboost::common {
1514
TEST(ParallelFor2d, CreateBlockedSpace2d) {
1615
constexpr size_t kDim1 = 5;
1716
constexpr size_t kDim2 = 3;
@@ -102,5 +101,14 @@ TEST(ParallelFor, Basic) {
102101
});
103102
ASSERT_FALSE(omp_in_parallel());
104103
}
105-
} // namespace common
106-
} // namespace xgboost
104+
105+
TEST(OmpGetNumThreads, Max) {
106+
#if defined(_OPENMP)
107+
auto n_threads = OmpGetNumThreads(1 << 18);
108+
ASSERT_LE(n_threads, std::thread::hardware_concurrency()); // le due to container
109+
n_threads = OmpGetNumThreads(0);
110+
ASSERT_GE(n_threads, 1);
111+
ASSERT_LE(n_threads, std::thread::hardware_concurrency());
112+
#endif
113+
}
114+
} // namespace xgboost::common

0 commit comments

Comments
 (0)