File tree Expand file tree Collapse file tree 4 files changed +28
-4
lines changed Expand file tree Collapse file tree 4 files changed +28
-4
lines changed Original file line number Diff line number Diff line change 7
7
#include < algorithm>
8
8
#include < limits>
9
9
#include < utility>
10
+ #include < rabit/rabit.h>
10
11
11
12
#include " xgboost/parameter.h"
12
13
#include " xgboost/data.h"
@@ -80,7 +81,12 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction {
80
81
81
82
// sycl::default_selector selector;
82
83
// qu_ = sycl::queue(selector);
83
- qu_ = sycl::queue (sycl::default_selector_v);
84
+ if (rabit::IsDistributed ()) {
85
+ std::vector<sycl::device> devices = sycl::device::get_devices ();
86
+ qu_ = sycl::queue (devices[rabit::GetRank ()]);
87
+ } else {
88
+ qu_ = sycl::queue (sycl::default_selector ());
89
+ }
84
90
}
85
91
86
92
void GetGradient (const HostDeviceVector<bst_float>& preds,
Original file line number Diff line number Diff line change 4
4
#include < cstddef>
5
5
#include < limits>
6
6
#include < mutex>
7
+ #include < rabit/rabit.h>
7
8
8
9
#include " data_oneapi.h"
9
10
@@ -335,7 +336,11 @@ class GPUPredictorOneAPI : public Predictor {
335
336
} else {
336
337
// sycl::default_selector selector;
337
338
// qu_ = sycl::queue(selector);
338
- qu_ = sycl::queue (sycl::default_selector_v);
339
+ if (rabit::IsDistributed ()) {
340
+ qu_ = sycl::queue (devices[rabit::GetRank ()]);
341
+ } else {
342
+ qu_ = sycl::queue (sycl::default_selector ());
343
+ }
339
344
}
340
345
}
341
346
Original file line number Diff line number Diff line change 3
3
#include < cmath>
4
4
#include < memory>
5
5
#include < vector>
6
+ #include < rabit/rabit.h>
6
7
7
8
#include " xgboost/host_device_vector.h"
8
9
#include " xgboost/json.h"
@@ -42,7 +43,12 @@ class RegLossObjOneAPI : public ObjFunction {
42
43
43
44
// sycl::default_selector selector;
44
45
// qu_ = sycl::queue(selector);
45
- qu_ = sycl::queue (sycl::default_selector_v);
46
+ if (rabit::IsDistributed ()) {
47
+ std::vector<sycl::device> devices = sycl::device::get_devices ();
48
+ qu_ = sycl::queue (devices[rabit::GetRank ()]);
49
+ } else {
50
+ qu_ = sycl::queue (sycl::default_selector ());
51
+ }
46
52
}
47
53
48
54
void GetGradient (const HostDeviceVector<bst_float>& preds,
Original file line number Diff line number Diff line change @@ -32,6 +32,9 @@ void QuantileHistMakerOneAPI::Configure(const Args& args) {
32
32
{
33
33
LOG (INFO) << " device_id = " << i << " , name = " << devices[i].get_info <sycl::info::device::name>();
34
34
}
35
+ if (rabit::IsDistributed ()) {
36
+ LOG (INFO) << " rabit rank = " << rabit::GetRank ();
37
+ }
35
38
if (param.device_id != GenericParameter::kDefaultId ) {
36
39
int n_devices = (int )devices.size ();
37
40
CHECK_LT (param.device_id , n_devices);
@@ -73,7 +76,11 @@ void GPUQuantileHistMakerOneAPI::Configure(const Args& args) {
73
76
if (param.device_id != GenericParameter::kDefaultId ) {
74
77
qu_ = sycl::queue (devices[param.device_id ]);
75
78
} else {
76
- qu_ = sycl::queue (sycl::default_selector_v);
79
+ if (rabit::IsDistributed ()) {
80
+ qu_ = sycl::queue (devices[rabit::GetRank ()]);
81
+ } else {
82
+ qu_ = sycl::queue (sycl::default_selector ());
83
+ }
77
84
}
78
85
79
86
// initialize pruner
You can’t perform that action at this time.
0 commit comments