Skip to content

Commit 22f6d0e

Browse files
authored
initial (#2)
Co-authored-by: Dmitry Razdoburdin <>
1 parent fed039c commit 22f6d0e

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

plugin/updater_oneapi/multiclass_obj_oneapi.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88
#include <limits>
99
#include <utility>
10+
#include <rabit/rabit.h>
1011

1112
#include "xgboost/parameter.h"
1213
#include "xgboost/data.h"
@@ -80,7 +81,12 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction {
8081

8182
// sycl::default_selector selector;
8283
// qu_ = sycl::queue(selector);
83-
qu_ = sycl::queue(sycl::default_selector_v);
84+
if (rabit::IsDistributed()) {
85+
std::vector<sycl::device> devices = sycl::device::get_devices();
86+
qu_ = sycl::queue(devices[rabit::GetRank()]);
87+
} else {
88+
qu_ = sycl::queue(sycl::default_selector());
89+
}
8490
}
8591

8692
void GetGradient(const HostDeviceVector<bst_float>& preds,

plugin/updater_oneapi/predictor_oneapi.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstddef>
55
#include <limits>
66
#include <mutex>
7+
#include <rabit/rabit.h>
78

89
#include "data_oneapi.h"
910

@@ -335,7 +336,11 @@ class GPUPredictorOneAPI : public Predictor {
335336
} else {
336337
// sycl::default_selector selector;
337338
// qu_ = sycl::queue(selector);
338-
qu_ = sycl::queue(sycl::default_selector_v);
339+
if (rabit::IsDistributed()) {
340+
qu_ = sycl::queue(devices[rabit::GetRank()]);
341+
} else {
342+
qu_ = sycl::queue(sycl::default_selector());
343+
}
339344
}
340345
}
341346

plugin/updater_oneapi/regression_obj_oneapi.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <cmath>
44
#include <memory>
55
#include <vector>
6+
#include <rabit/rabit.h>
67

78
#include "xgboost/host_device_vector.h"
89
#include "xgboost/json.h"
@@ -42,7 +43,12 @@ class RegLossObjOneAPI : public ObjFunction {
4243

4344
// sycl::default_selector selector;
4445
// qu_ = sycl::queue(selector);
45-
qu_ = sycl::queue(sycl::default_selector_v);
46+
if (rabit::IsDistributed()) {
47+
std::vector<sycl::device> devices = sycl::device::get_devices();
48+
qu_ = sycl::queue(devices[rabit::GetRank()]);
49+
} else {
50+
qu_ = sycl::queue(sycl::default_selector());
51+
}
4652
}
4753

4854
void GetGradient(const HostDeviceVector<bst_float>& preds,

plugin/updater_oneapi/updater_quantile_hist_oneapi.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ void QuantileHistMakerOneAPI::Configure(const Args& args) {
3232
{
3333
LOG(INFO) << "device_id = " << i << ", name = " << devices[i].get_info<sycl::info::device::name>();
3434
}
35+
if (rabit::IsDistributed()) {
36+
LOG(INFO) << "rabit rank = " << rabit::GetRank();
37+
}
3538
if (param.device_id != GenericParameter::kDefaultId) {
3639
int n_devices = (int)devices.size();
3740
CHECK_LT(param.device_id, n_devices);
@@ -73,7 +76,11 @@ void GPUQuantileHistMakerOneAPI::Configure(const Args& args) {
7376
if (param.device_id != GenericParameter::kDefaultId) {
7477
qu_ = sycl::queue(devices[param.device_id]);
7578
} else {
76-
qu_ = sycl::queue(sycl::default_selector_v);
79+
if (rabit::IsDistributed()) {
80+
qu_ = sycl::queue(devices[rabit::GetRank()]);
81+
} else {
82+
qu_ = sycl::queue(sycl::default_selector());
83+
}
7784
}
7885

7986
// initialize pruner

0 commit comments

Comments
 (0)