File tree Expand file tree Collapse file tree 3 files changed +22
-13
lines changed Expand file tree Collapse file tree 3 files changed +22
-13
lines changed Original file line number Diff line number Diff line change 3
3
*/
4
4
#include " threading_utils.h"
5
5
6
- #include < algorithm> // for max
6
+ #include < algorithm> // for max, min
7
7
#include < exception> // for exception
8
8
#include < filesystem> // for path, exists
9
9
#include < fstream> // for ifstream
@@ -99,17 +99,18 @@ std::int32_t GetCfsCPUCount() noexcept {
99
99
return -1 ;
100
100
}
101
101
102
- std::int32_t OmpGetNumThreads (std::int32_t n_threads) {
102
+ std::int32_t OmpGetNumThreads (std::int32_t n_threads) noexcept ( true ) {
103
103
// Don't use parallel if we are in a parallel region.
104
104
if (omp_in_parallel ()) {
105
105
return 1 ;
106
106
}
107
+ // Honor the openmp thread limit, which can be set via environment variable.
108
+ auto max_n_threads = std::min ({omp_get_num_procs (), omp_get_max_threads (), OmpGetThreadLimit ()});
107
109
// If -1 or 0 is specified by the user, we default to maximum number of threads.
108
110
if (n_threads <= 0 ) {
109
- n_threads = std::min ( omp_get_num_procs (), omp_get_max_threads ()) ;
111
+ n_threads = max_n_threads ;
110
112
}
111
- // Honor the openmp thread limit, which can be set via environment variable.
112
- n_threads = std::min (n_threads, OmpGetThreadLimit ());
113
+ n_threads = std::min (n_threads, max_n_threads);
113
114
n_threads = std::max (n_threads, 1 );
114
115
return n_threads;
115
116
}
Original file line number Diff line number Diff line change @@ -257,9 +257,9 @@ inline std::int32_t OmpGetThreadLimit() {
257
257
std::int32_t GetCfsCPUCount () noexcept ;
258
258
259
259
/**
 * @brief Get the number of available threads based on n_threads specified by users.
 *
 * @param n_threads Requested number of threads. A value <= 0 selects the maximum
 *                  number of threads available.
 *
 * @return 1 when called from inside an OpenMP parallel region; otherwise the request
 *         capped by the number of processors, the OpenMP maximum thread count and the
 *         OpenMP thread limit, and never less than 1.
 */
std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true);
263
263
264
264
/* !
265
265
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
Original file line number Diff line number Diff line change 1
1
/* *
2
- * Copyright 2019-2023 by XGBoost Contributors
2
+ * Copyright 2019-2024, XGBoost Contributors
3
3
*/
4
4
#include < gtest/gtest.h>
5
5
6
6
#include < cstddef> // std::size_t
7
+ #include < thread> // for std::thread
7
8
8
9
#include " ../../../src/common/threading_utils.h" // BlockedSpace2d,ParallelFor2d,ParallelFor
9
10
#include " dmlc/omp.h" // omp_in_parallel
10
11
#include " xgboost/context.h" // Context
11
12
12
- namespace xgboost {
13
- namespace common {
14
-
13
+ namespace xgboost ::common {
15
14
TEST (ParallelFor2d, CreateBlockedSpace2d) {
16
15
constexpr size_t kDim1 = 5 ;
17
16
constexpr size_t kDim2 = 3 ;
@@ -102,5 +101,14 @@ TEST(ParallelFor, Basic) {
102
101
});
103
102
ASSERT_FALSE (omp_in_parallel ());
104
103
}
105
- } // namespace common
106
- } // namespace xgboost
104
+
105
+ TEST (OmpGetNumThreads, Max) {
106
+ #if defined(_OPENMP)
107
+ auto n_threads = OmpGetNumThreads (1 << 18 );
108
+ ASSERT_LE (n_threads, std::thread::hardware_concurrency ()); // le due to container
109
+ n_threads = OmpGetNumThreads (0 );
110
+ ASSERT_GE (n_threads, 1 );
111
+ ASSERT_LE (n_threads, std::thread::hardware_concurrency ());
112
+ #endif
113
+ }
114
+ } // namespace xgboost::common
You can’t perform that action at this time.
0 commit comments