Skip to content

Commit 0d7821f

Browse files
authored
Expose the global openmp thread for the dask interface. (dmlc#11175)
1 parent b57840f commit 0d7821f

File tree

7 files changed

+54
-8
lines changed

7 files changed

+54
-8
lines changed

doc/parameter.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ The following parameters can be set in the global scope, using :py:func:`xgboost
3131
(compiled) with the RMM plugin enabled. Valid values are ``true`` and ``false``. See
3232
:doc:`/python/rmm-examples/index` for details.
3333

34+
* ``nthread``: Set the global number of threads for OpenMP. Use this only when you need to
35+
override some OpenMP-related environment variables like ``OMP_NUM_THREADS``. Otherwise,
36+
the ``nthread`` parameter from the Booster and the DMatrix should be preferred as the
37+
former sets the global variable and might cause conflicts with other libraries.
38+
3439
******************
3540
General Parameters
3641
******************

include/xgboost/global_config.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace xgboost {
1616
struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
1717
std::int32_t verbosity{1};
1818
bool use_rmm{false};
19+
// This is not a dmlc parameter to avoid conflict with the context class.
20+
std::int32_t nthread{0};
1921
DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
2022
DMLC_DECLARE_FIELD(verbosity)
2123
.set_range(0, 3)

python-package/xgboost/dask/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ def do_train( # pylint: disable=too-many-positional-arguments
805805
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
806806

807807
local_history: TrainingCallback.EvalsLog = {}
808+
global_config.update({"nthread": n_threads})
808809

809810
with CommunicatorContext(**coll_args), config.config_context(**global_config):
810811
Xy, evals = _get_dmatrices(

src/c_api/c_api.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2014-2024, XGBoost Contributors
2+
* Copyright 2014-2025, XGBoost Contributors
33
*/
44
#include "xgboost/c_api.h"
55

@@ -143,7 +143,19 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
143143
xgboost_CHECK_C_ARG_PTR(json_str);
144144
Json config{Json::Load(StringView{json_str})};
145145

146-
for (auto& items : get<Object>(config)) {
146+
// handle nthread, it's not a dmlc parameter.
147+
auto& obj = get<Object>(config);
148+
auto it = obj.find("nthread");
149+
if (it != obj.cend()) {
150+
auto nthread = OptionalArg<Integer>(config, "nthread", Integer::Int{0});
151+
if (nthread > 0) {
152+
omp_set_num_threads(nthread);
153+
GlobalConfigThreadLocalStore::Get()->nthread = nthread;
154+
}
155+
get<Object>(config).erase("nthread");
156+
}
157+
158+
for (auto &items : obj) {
147159
switch (items.second.GetValue().Type()) {
148160
case xgboost::Value::ValueKind::kInteger: {
149161
items.second = String{std::to_string(get<Integer const>(items.second))};
@@ -183,6 +195,7 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
183195
}
184196
LOG(FATAL) << ss.str() << " }";
185197
}
198+
186199
API_END();
187200
}
188201

@@ -216,6 +229,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
216229
}
217230
}
218231

232+
config["nthread"] = GlobalConfigThreadLocalStore::Get()->nthread;
219233
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
220234
Json::Dump(config, &local.ret_str);
221235

src/global_config.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
* \author Hyunsu Cho
66
*/
77

8-
#include <dmlc/thread_local.h>
98
#include "xgboost/global_config.h"
109

10+
#include <dmlc/thread_local.h>
11+
1112
namespace xgboost {
1213
DMLC_REGISTER_PARAMETER(GlobalConfiguration);
1314

14-
void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
15+
void InitNewThread::operator()() const {
16+
*GlobalConfigThreadLocalStore::Get() = config;
17+
if (config.nthread > 0) {
18+
omp_set_num_threads(config.nthread);
19+
}
20+
}
1521
} // namespace xgboost

tests/cpp/test_global_config.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
/**
2+
* Copyright 2020-2025, XGBoost Contributors
3+
*/
14
#include <gtest/gtest.h>
5+
#include <xgboost/c_api.h>
6+
#include <xgboost/global_config.h>
27
#include <xgboost/json.h>
38
#include <xgboost/logging.h>
4-
#include <xgboost/global_config.h>
59

610
namespace xgboost {
7-
811
TEST(GlobalConfiguration, Verbosity) {
912
// Configure verbosity via global configuration
1013
Json config{JsonObject()};
@@ -15,7 +18,7 @@ TEST(GlobalConfiguration, Verbosity) {
1518
EXPECT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kSilent);
1619
EXPECT_NE(ConsoleLogger::LogVerbosity::kSilent, ConsoleLogger::DefaultVerbosity());
1720
// GetConfig() should also return updated verbosity
18-
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
21+
Json current_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
1922
EXPECT_EQ(get<String>(current_config["verbosity"]), "0");
2023
}
2124

@@ -25,8 +28,18 @@ TEST(GlobalConfiguration, UseRMM) {
2528
auto& global_config = *GlobalConfigThreadLocalStore::Get();
2629
FromJson(config, &global_config);
2730
// GetConfig() should return updated use_rmm flag
28-
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
31+
Json current_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
2932
EXPECT_EQ(get<String>(current_config["use_rmm"]), "1");
3033
}
3134

35+
TEST(GlobalConfiguration, Threads) {
36+
char const* config;
37+
ASSERT_EQ(XGBGetGlobalConfig(&config), 0);
38+
auto jconfig = Json::Load(config);
39+
auto nthread = get<Integer const>(jconfig["nthread"]);
40+
ASSERT_LE(nthread, 0);
41+
auto n_omp = omp_get_num_threads();
42+
ASSERT_EQ(XGBSetGlobalConfig(config), 0);
43+
ASSERT_EQ(n_omp, omp_get_num_threads());
44+
}
3245
} // namespace xgboost

tests/python/test_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,8 @@ def test_thread_safety():
6969

7070
for f in futures:
7171
f.result()
72+
73+
74+
def test_nthread() -> None:
75+
config = xgb.get_config()
76+
assert config["nthread"] == 0

0 commit comments

Comments
 (0)