
Commit 2d1ca00

Support CUDA ordinal for external memory. (dmlc#11219)
1 parent 105aa42 commit 2d1ca00

File tree

6 files changed (+86, -15 lines):

  include/xgboost/global_config.h
  src/common/cuda_rt_utils.cc
  src/common/cuda_rt_utils.h
  src/data/ellpack_page_source.cu
  src/global_config.cc
  tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py


include/xgboost/global_config.h

Lines changed: 4 additions & 2 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2024, XGBoost Contributors
+ * Copyright 2020-2025, XGBoost Contributors
  * \file global_config.h
  * \brief Global configuration for XGBoost
  * \author Hyunsu Cho
@@ -31,9 +31,11 @@ struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
 using GlobalConfigThreadLocalStore = dmlc::ThreadLocalStore<GlobalConfiguration>;
 
 struct InitNewThread {
-  GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();
+  GlobalConfiguration config;
+  std::int32_t device{-1};
 
   void operator()() const;
+  InitNewThread();
 };
 }  // namespace xgboost

src/common/cuda_rt_utils.cc

Lines changed: 13 additions & 6 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2024, XGBoost Contributors
+ * Copyright 2015-2025, XGBoost Contributors
  */
 #include "cuda_rt_utils.h"
 
@@ -28,9 +28,14 @@ std::int32_t AllVisibleGPUs() {
   return n_visgpus;
 }
 
-std::int32_t CurrentDevice() {
-  std::int32_t device = 0;
-  dh::safe_cuda(cudaGetDevice(&device));
+std::int32_t CurrentDevice(bool raise) {
+  std::int32_t device = -1;
+  if (raise) {
+    dh::safe_cuda(cudaGetDevice(&device));
+  } else if (cudaGetDevice(&device) != cudaSuccess) {
+    // Return -1 as an error.
+    return -1;
+  }
   return device;
 }
 
@@ -100,8 +105,10 @@ void DrVersion(std::int32_t* major, std::int32_t* minor) {
 #else
 std::int32_t AllVisibleGPUs() { return 0; }
 
-std::int32_t CurrentDevice() {
-  common::AssertGPUSupport();
+std::int32_t CurrentDevice(bool raise) {
+  if (raise) {
+    common::AssertGPUSupport();
+  }
   return -1;
 }

src/common/cuda_rt_utils.h

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2024, XGBoost contributors
+ * Copyright 2024-2025, XGBoost contributors
  */
 #pragma once
 #include <cstddef>  // for size_t
@@ -12,7 +12,10 @@
 namespace xgboost::curt {
 std::int32_t AllVisibleGPUs();
 
-std::int32_t CurrentDevice();
+/**
+ * @param raise Raise error if XGBoost is not compiled with CUDA, or GPU is not available.
+ */
+std::int32_t CurrentDevice(bool raise = true);
 
 // Whether the device supports coherently accessing pageable memory without calling
 // `cudaHostRegister` on it
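
The new `raise` flag turns the ordinal query into a best-effort probe when a GPU may legitimately be absent. Below is a minimal usage sketch, not part of the commit; the include path and the `main` wrapper are illustrative only.

// Sketch: probe the current CUDA ordinal without raising when no device (or no
// CUDA build) is available.
#include <cstdint>   // for std::int32_t
#include <iostream>  // for std::cout

#include "common/cuda_rt_utils.h"  // for CurrentDevice; in-tree path, illustrative

int main() {
  // raise = false: returns -1 instead of raising when cudaGetDevice fails or
  // XGBoost was built without CUDA.
  std::int32_t ordinal = xgboost::curt::CurrentDevice(/*raise=*/false);
  if (ordinal < 0) {
    std::cout << "No usable CUDA device, staying on CPU.\n";
  } else {
    std::cout << "Current CUDA ordinal: " << ordinal << "\n";
  }
  // raise = true (the default) checks the CUDA call and asserts GPU support,
  // which is the right choice when a device is required:
  //   std::int32_t required = xgboost::curt::CurrentDevice();
  return 0;
}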

src/data/ellpack_page_source.cu

Lines changed: 2 additions & 3 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2024, XGBoost contributors
+ * Copyright 2019-2025, XGBoost contributors
  */
 #include <algorithm>  // for count_if
 #include <cstddef>    // for size_t
@@ -8,8 +8,7 @@
 #include <numeric>  // for accumulate
 #include <utility>  // for move
 
-#include "../common/common.h"  // for safe_cuda
-#include "../common/common.h"  // for HumanMemUnit
+#include "../common/common.h"               // for HumanMemUnit, safe_cuda
 #include "../common/cuda_rt_utils.h"        // for SetDevice
 #include "../common/device_helpers.cuh"     // for CUDAStreamView, DefaultStream
 #include "../common/ref_resource_view.cuh"  // for MakeFixedVecWithCudaMalloc

src/global_config.cc

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2024, XGBoost Contributors
+ * Copyright 2020-2025, XGBoost Contributors
  * \file global_config.cc
  * \brief Global configuration for XGBoost
  * \author Hyunsu Cho
@@ -9,13 +9,21 @@
 
 #include <dmlc/thread_local.h>
 
+#include "common/cuda_rt_utils.h"  // for SetDevice
+
 namespace xgboost {
 DMLC_REGISTER_PARAMETER(GlobalConfiguration);
 
+InitNewThread::InitNewThread()
+    : config{*GlobalConfigThreadLocalStore::Get()}, device{curt::CurrentDevice(false)} {}
+
 void InitNewThread::operator()() const {
   *GlobalConfigThreadLocalStore::Get() = config;
   if (config.nthread > 0) {
     omp_set_num_threads(config.nthread);
  }
+  if (device >= 0) {
+    curt::SetDevice(this->device);
+  }
 }
 }  // namespace xgboost
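
Taken together, the constructor now snapshots the calling thread's global configuration and CUDA ordinal, and `operator()` re-applies both on whichever thread runs it. The following is a minimal sketch of that capture-and-restore pattern, assuming a hypothetical `spawn_worker` helper and in-tree include paths; it is not code from the commit.

// Sketch: hand InitNewThread to a newly spawned thread so the worker inherits
// the parent's CUDA ordinal. spawn_worker is hypothetical.
#include <thread>  // for std::thread

#include "xgboost/global_config.h"  // for InitNewThread

void spawn_worker() {
  // Constructed on the calling thread: copies the thread-local GlobalConfiguration
  // and records the current CUDA ordinal (or -1 when no device is active).
  xgboost::InitNewThread init_new_thread;

  std::thread worker{[init_new_thread] {
    // Runs first on the new thread: restores the configuration, applies nthread to
    // OpenMP, and calls curt::SetDevice with the captured ordinal when it is >= 0.
    init_new_thread();
    // ... GPU work here now targets the inherited device ...
  }};
  worker.join();
}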

tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py

Lines changed: 53 additions & 1 deletion
@@ -1,10 +1,17 @@
-"""Copyright 2024, XGBoost contributors"""
+"""Copyright 2024-2025, XGBoost contributors"""
+
+from functools import partial, update_wrapper
+from typing import Any
 
 import pytest
 from dask_cuda import LocalCUDACluster
 from distributed import Client
 
+import xgboost as xgb
+from xgboost import collective as coll
+from xgboost import testing as tm
 from xgboost.testing.dask import check_external_memory, get_rabit_args
+from xgboost.tracker import RabitTracker
 
 
 @pytest.mark.parametrize("is_qdm", [True, False])
@@ -22,3 +29,48 @@ def test_external_memory(is_qdm: bool) -> None:
         is_qdm=is_qdm,
     )
     client.gather(futs)
+
+
+@pytest.mark.skipif(**tm.no_loky())
+def test_extmem_qdm_distributed() -> None:
+    from loky import get_reusable_executor
+
+    n_samples_per_batch = 2048
+    n_features = 128
+    n_batches = 8
+
+    def do_train(ordinal: int) -> None:
+        it = tm.IteratorForTest(
+            *tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy=True),
+            cache="cache",
+            on_host=True,
+        )
+
+        Xy = xgb.ExtMemQuantileDMatrix(it)
+        results: dict[str, Any] = {}
+        booster = xgb.train(
+            {"device": f"cuda:{ordinal}"},
+            num_boost_round=2,
+            dtrain=Xy,
+            evals=[(Xy, "Train")],
+            evals_result=results,
+        )
+        assert tm.non_increasing(results["Train"]["rmse"])
+
+    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2)
+    tracker.start()
+    args = tracker.worker_args()
+
+    def local_test(worker_id: int, rabit_args: dict) -> None:
+        import cupy as cp
+
+        cp.cuda.runtime.setDevice(worker_id)
+
+        with coll.CommunicatorContext(**rabit_args, DMLC_TASK_ID=str(worker_id)):
+            assert coll.get_rank() == worker_id
+            do_train(coll.get_rank())
+
+    n_workers = 2
+    fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
+    with get_reusable_executor(max_workers=n_workers) as pool:
+        results = pool.map(fn, range(n_workers))
