Commit 5ecb394

Merge sycl plugin support to xgboost v2.1.0 (#53)

* add tests for prediction cache
* add the rest of plugin functionality
* linting
* fix compilation failure
* fix dispatching
* add pruner initialisation
* fix fp64 support determination
* make sycl updater launchable from python
* apply PredictRow optimisation
* fix x86 build

Co-authored-by: Dmitry Razdoburdin <>
1 parent 7e94cbf commit 5ecb394

20 files changed: +1188 −141 lines

include/xgboost/predictor.h

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ class Predictor {
    */
   virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
                             const gbm::GBTreeModel& model, uint32_t tree_begin,
-                            uint32_t tree_end = 0) const = 0;
+                            uint32_t tree_end = 0, bool training = false) const = 0;
 
   /**
    * \brief Inplace prediction.
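Note on the signature change: because the new training parameter is defaulted, existing call sites keep compiling, but every concrete predictor must extend its override to match; with override, a stale signature fails at compile time rather than silently no longer overriding the virtual. A minimal sketch (hypothetical subclass, not from this commit):

// Hypothetical subclass; only PredictBatch's signature comes from the commit.
class MyPredictor : public xgboost::Predictor {
 public:
  void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
                    const gbm::GBTreeModel& model, uint32_t tree_begin,
                    uint32_t tree_end = 0, bool training = false) const override {
    // 'override' turns a missed signature update into a compile error.
  }
};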

plugin/sycl/common/hist_util.cc

Lines changed: 27 additions & 0 deletions
@@ -31,6 +31,33 @@ template void InitHist(::sycl::queue qu,
                        GHistRow<double, MemoryType::on_device>* hist,
                        size_t size, ::sycl::event* event);
 
+/*!
+ * \brief Copy histogram from src to dst
+ */
+template<typename GradientSumT>
+void CopyHist(::sycl::queue qu,
+              GHistRow<GradientSumT, MemoryType::on_device>* dst,
+              const GHistRow<GradientSumT, MemoryType::on_device>& src,
+              size_t size) {
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
+  const GradientSumT* psrc = reinterpret_cast<const GradientSumT*>(src.DataConst());
+
+  qu.submit([&](::sycl::handler& cgh) {
+    cgh.parallel_for<>(::sycl::range<1>(2 * size), [=](::sycl::item<1> pid) {
+      const size_t i = pid.get_id(0);
+      pdst[i] = psrc[i];
+    });
+  }).wait();
+}
+template void CopyHist(::sycl::queue qu,
+                       GHistRow<float, MemoryType::on_device>* dst,
+                       const GHistRow<float, MemoryType::on_device>& src,
+                       size_t size);
+template void CopyHist(::sycl::queue qu,
+                       GHistRow<double, MemoryType::on_device>* dst,
+                       const GHistRow<double, MemoryType::on_device>& src,
+                       size_t size);
+
 /*!
  * \brief Compute Subtraction: dst = src1 - src2
  */
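The kernel's range is 2 * size because each of the size histogram bins holds a gradient/hessian pair of GradientSumT scalars. Under that layout assumption (not stated in the commit), the same device-to-device copy could also be expressed with SYCL 2020's queue::memcpy; a minimal sketch:

#include <cstddef>
#include <sycl/sycl.hpp>

// Sketch only: assumes bins are contiguous pairs of GradientSumT scalars
// in USM device memory, as the 2 * size kernel range suggests.
template <typename GradientSumT>
void CopyHistViaMemcpy(::sycl::queue qu, GradientSumT* pdst,
                       const GradientSumT* psrc, std::size_t size) {
  // queue::memcpy copies raw bytes between USM allocations.
  qu.memcpy(pdst, psrc, 2 * size * sizeof(GradientSumT)).wait();
}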

plugin/sycl/common/hist_util.h

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,15 @@ void InitHist(::sycl::queue qu,
               GHistRow<GradientSumT, MemoryType::on_device>* hist,
               size_t size, ::sycl::event* event);
 
+/*!
+ * \brief Copy histogram from src to dst
+ */
+template<typename GradientSumT>
+void CopyHist(::sycl::queue qu,
+              GHistRow<GradientSumT, MemoryType::on_device>* dst,
+              const GHistRow<GradientSumT, MemoryType::on_device>& src,
+              size_t size);
+
 /*!
  * \brief Compute subtraction: dst = src1 - src2
 */

plugin/sycl/data.h

Lines changed: 36 additions & 9 deletions
@@ -96,12 +96,9 @@ class USMVector {
   ~USMVector() {
   }
 
-  USMVector<T>& operator=(const USMVector<T>& other) {
-    size_ = other.size_;
-    capacity_ = other.capacity_;
-    data_ = other.data_;
-    return *this;
-  }
+  USMVector(const USMVector&) = delete;
+
+  USMVector<T>& operator=(const USMVector<T>& other) = delete;
 
   T* Data() { return data_.get(); }
   const T* DataConst() const { return data_.get(); }

@@ -139,6 +136,17 @@ class USMVector {
     }
   }
 
+  /* Resize without keeping the data */
+  void ResizeNoCopy(::sycl::queue* qu, size_t size_new) {
+    if (size_new <= capacity_) {
+      size_ = size_new;
+    } else {
+      size_ = size_new;
+      capacity_ = size_new;
+      data_ = allocate_memory_(qu, size_);
+    }
+  }
+
   void Resize(::sycl::queue* qu, size_t size_new, T v) {
     if (size_new <= size_) {
       size_ = size_new;

@@ -162,7 +170,7 @@ class USMVector {
     if (size_new <= size_) {
       size_ = size_new;
     } else if (size_new <= capacity_) {
-      auto event = qu->fill(data_.get() + size_, v, size_new - size_);
+      *event = qu->fill(data_.get() + size_, v, size_new - size_, *event);
       size_ = size_new;
     } else {
       size_t size_old = size_;

@@ -215,16 +223,35 @@ class USMVector {
 
 /* Wrapper for DMatrix which stores all batches in a single USM buffer */
 struct DeviceMatrix {
-  DMatrix* p_mat;  // Pointer to the original matrix on the host
+  DMatrix* p_mat = nullptr;  // Pointer to the original matrix on the host
   ::sycl::queue qu_;
   USMVector<size_t, MemoryType::on_device> row_ptr;
   USMVector<Entry, MemoryType::on_device> data;
   size_t total_offset;
+  bool is_from_cache = false;
 
   DeviceMatrix() = default;
 
-  void Init(::sycl::queue qu, DMatrix* dmat) {
+  DeviceMatrix(const DeviceMatrix& other) = delete;
+
+  DeviceMatrix& operator=(const DeviceMatrix& other) = delete;
+
+  // During training the same dmatrix is used, so we don't need to reload it on device
+  bool ReinitializationRequired(DMatrix* dmat, bool training) {
+    if (!training) return true;
+    if (p_mat != dmat) return true;
+    return false;
+  }
+
+  void Init(::sycl::queue qu, DMatrix* dmat, bool training = false) {
     qu_ = qu;
+    if (!ReinitializationRequired(dmat, training)) {
+      is_from_cache = true;
+      return;
+    }
+
+    is_from_cache = false;
+
     p_mat = dmat;
 
     size_t num_row = 0;
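The caching contract introduced here is easiest to see from a call sequence. A hypothetical driver sketch (not from the commit; only the training path may reuse the previously uploaded matrix):

#include <cassert>

// Hypothetical usage of DeviceMatrix::Init with the new caching logic.
void Example(::sycl::queue qu, xgboost::DMatrix* dmat) {
  xgboost::sycl::DeviceMatrix dm;
  dm.Init(qu, dmat, /*training=*/true);   // first call: uploads to the device
  dm.Init(qu, dmat, /*training=*/true);   // same DMatrix while training: cache hit
  assert(dm.is_from_cache);
  dm.Init(qu, dmat, /*training=*/false);  // inference: always re-uploads
  assert(!dm.is_from_cache);
}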

plugin/sycl/predictor/predictor.cc

Lines changed: 43 additions & 24 deletions
@@ -1,10 +1,11 @@
 /*!
- * Copyright by Contributors 2017-2023
+ * Copyright by Contributors 2017-2024
  */
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
-#pragma GCC diagnostic ignored "-W#pragma-messages"
-#pragma GCC diagnostic pop
+#include <dmlc/timer.h>
+// #pragma GCC diagnostic push
+// #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+// #pragma GCC diagnostic ignored "-W#pragma-messages"
+// #pragma GCC diagnostic pop
 
 #include <cstddef>
 #include <limits>

@@ -158,6 +159,8 @@ float GetLeafWeight(const Node* nodes, const float* fval_buff) {
 
 template <bool any_missing>
 void DevicePredictInternal(::sycl::queue* qu,
+                           USMVector<float, MemoryType::on_device>* fval_buff,
+                           USMVector<uint8_t, MemoryType::on_device>* miss_buff,
                            const sycl::DeviceMatrix& dmat,
                            HostDeviceVector<float>* out_preds,
                            const gbm::GBTreeModel& model,

@@ -178,15 +181,17 @@ void DevicePredictInternal(::sycl::queue* qu,
   int num_rows = dmat.row_ptr.Size() - 1;
   int num_group = model.learner_model_param->num_output_group;
 
-  USMVector<float, MemoryType::on_device> fval_buff(qu, num_features * num_rows);
-  USMVector<uint8_t, MemoryType::on_device> miss_buff;
-  auto* fval_buff_ptr = fval_buff.Data();
+  bool update_buffs = !dmat.is_from_cache;
 
   std::vector<::sycl::event> events(1);
-  if constexpr (any_missing) {
-    miss_buff.Resize(qu, num_features * num_rows, 1, &events[0]);
+  if (update_buffs) {
+    fval_buff->Resize(qu, num_features * num_rows);
+    if constexpr (any_missing) {
+      miss_buff->Resize(qu, num_features * num_rows, 1, &events[0]);
+    }
   }
-  auto* miss_buff_ptr = miss_buff.Data();
+  auto* fval_buff_ptr = fval_buff->Data();
+  auto* miss_buff_ptr = miss_buff->Data();
 
   auto& out_preds_vec = out_preds->HostVector();
   ::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());

@@ -198,12 +203,14 @@ void DevicePredictInternal(::sycl::queue* qu,
       auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
       auto* miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;
 
-      const Entry* first_entry = data + row_ptr[row_idx];
-      const Entry* last_entry = data + row_ptr[row_idx + 1];
-      for (const Entry* entry = first_entry; entry < last_entry; entry += 1) {
-        fval_buff_row_ptr[entry->index] = entry->fvalue;
-        if constexpr (any_missing) {
-          miss_buff_row_ptr[entry->index] = 0;
+      if (update_buffs) {
+        const Entry* first_entry = data + row_ptr[row_idx];
+        const Entry* last_entry = data + row_ptr[row_idx + 1];
+        for (const Entry* entry = first_entry; entry < last_entry; entry += 1) {
+          fval_buff_row_ptr[entry->index] = entry->fvalue;
+          if constexpr (any_missing) {
+            miss_buff_row_ptr[entry->index] = 0;
+          }
         }
       }
 

@@ -241,6 +248,7 @@ class Predictor : public xgboost::Predictor {
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
                           const gbm::GBTreeModel& model) const override {
+    predictor_monitor_.Start("InitOutPredictions");
     CHECK_NE(model.learner_model_param->num_output_group, 0);
     size_t n = model.learner_model_param->num_output_group * info.num_row_;
     const auto& base_margin = info.base_margin_.Data()->HostVector();

@@ -268,33 +276,40 @@ class Predictor : public xgboost::Predictor {
       }
       std::fill(out_preds_h.begin(), out_preds_h.end(), base_score);
     }
+    predictor_monitor_.Stop("InitOutPredictions");
   }
 
   explicit Predictor(Context const* context) :
       xgboost::Predictor::Predictor{context},
-      cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {}
+      cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {
+    predictor_monitor_.Init("SyclPredictor");
+  }
 
   void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
                     const gbm::GBTreeModel &model, uint32_t tree_begin,
-                    uint32_t tree_end = 0) const override {
+                    uint32_t tree_end = 0, bool training = false) const override {
     ::sycl::queue qu = device_manager.GetQueue(ctx_->Device());
-    // TODO(razdoburdin): remove temporary workaround after cache fix
-    sycl::DeviceMatrix device_matrix;
-    device_matrix.Init(qu, dmat);
+    predictor_monitor_.Start("InitDeviceMatrix");
+    device_matrix.Init(qu, dmat, training);
+    predictor_monitor_.Stop("InitDeviceMatrix");
 
     auto* out_preds = &predts->predictions;
     if (tree_end == 0) {
       tree_end = model.trees.size();
     }
 
+    predictor_monitor_.Start("DevicePredictInternal");
     if (tree_begin < tree_end) {
      const bool any_missing = !(dmat->IsDense());
      if (any_missing) {
-        DevicePredictInternal<true>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
+        DevicePredictInternal<true>(&qu, &fval_buff, &miss_buff, device_matrix,
+                                    out_preds, model, tree_begin, tree_end);
      } else {
-        DevicePredictInternal<false>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
+        DevicePredictInternal<false>(&qu, &fval_buff, &miss_buff, device_matrix,
+                                     out_preds, model, tree_begin, tree_end);
      }
    }
+    predictor_monitor_.Stop("DevicePredictInternal");
  }
 
  bool InplacePredict(std::shared_ptr<DMatrix> p_m,

@@ -341,7 +356,11 @@ class Predictor : public xgboost::Predictor {
 
  private:
   DeviceManager device_manager;
+  mutable sycl::DeviceMatrix device_matrix;
+  mutable USMVector<float, MemoryType::on_device> fval_buff;
+  mutable USMVector<uint8_t, MemoryType::on_device> miss_buff;
 
+  mutable xgboost::common::Monitor predictor_monitor_;
   std::unique_ptr<xgboost::Predictor> cpu_predictor;
 };
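Taken together, the new mutable members turn the predictor into a per-object cache: during training, PredictBatch is called once per boosting round with the same DMatrix, so only the first call pays for the device upload and buffer fill. A hypothetical call sequence (illustration only; predictor, dmat, predts and model are assumed to be set up elsewhere):

// Hypothetical driver code, not from the commit.
void TrainLoopSketch(xgboost::Predictor* predictor, xgboost::DMatrix* dmat,
                     xgboost::PredictionCacheEntry* predts,
                     const xgboost::gbm::GBTreeModel& model, int rounds) {
  for (int r = 0; r < rounds; ++r) {
    // Round 0 uploads dmat and fills fval_buff/miss_buff; later rounds reuse them.
    predictor->PredictBatch(dmat, predts, model, /*tree_begin=*/0,
                            /*tree_end=*/0, /*training=*/true);
  }
}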

plugin/sycl/tree/hist_row_adder.h

Lines changed: 35 additions & 0 deletions
@@ -39,6 +39,41 @@ class BatchHistRowsAdder: public HistRowsAdder<GradientSumT> {
   }
 };
 
+template <typename GradientSumT>
+class DistributedHistRowsAdder: public HistRowsAdder<GradientSumT> {
+ public:
+  void AddHistRows(HistUpdater<GradientSumT>* builder,
+                   std::vector<int>* sync_ids, RegTree *p_tree) override {
+    builder->builder_monitor_.Start("AddHistRows");
+    const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size();
+    const size_t subtaction_size = builder->nodes_for_subtraction_trick_.size();
+    std::vector<int> merged_node_ids(explicit_size + subtaction_size);
+    for (size_t i = 0; i < explicit_size; ++i) {
+      merged_node_ids[i] = builder->nodes_for_explicit_hist_build_[i].nid;
+    }
+    for (size_t i = 0; i < subtaction_size; ++i) {
+      merged_node_ids[explicit_size + i] =
+        builder->nodes_for_subtraction_trick_[i].nid;
+    }
+    std::sort(merged_node_ids.begin(), merged_node_ids.end());
+    sync_ids->clear();
+    for (auto const& nid : merged_node_ids) {
+      if ((*p_tree)[nid].IsLeftChild()) {
+        builder->hist_.AddHistRow(nid);
+        builder->hist_local_worker_.AddHistRow(nid);
+        sync_ids->push_back(nid);
+      }
+    }
+    for (auto const& nid : merged_node_ids) {
+      if (!((*p_tree)[nid].IsLeftChild())) {
+        builder->hist_.AddHistRow(nid);
+        builder->hist_local_worker_.AddHistRow(nid);
+      }
+    }
+    builder->builder_monitor_.Stop("AddHistRows");
+  }
+};
+
 }  // namespace tree
 }  // namespace sycl
 }  // namespace xgboost
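The two passes order the work so that every left child is registered first and recorded in sync_ids; only those histograms are reduced across workers, while right children are later reconstructed via the subtraction trick (see the synchronizer below). A toy, self-contained illustration of the id bookkeeping (hypothetical node ids and left-child rule, not the commit's code):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> merged_node_ids{2, 1, 4, 3};         // explicit + subtraction nodes
  auto is_left = [](int nid) { return nid % 2 == 1; };  // toy convention only
  std::sort(merged_node_ids.begin(), merged_node_ids.end());
  std::vector<int> sync_ids;
  for (int nid : merged_node_ids) {
    if (is_left(nid)) sync_ids.push_back(nid);          // histograms to be allreduced
  }
  assert((sync_ids == std::vector<int>{1, 3}));
  return 0;
}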

plugin/sycl/tree/hist_synchronizer.h

Lines changed: 56 additions & 0 deletions
@@ -61,6 +61,62 @@ class BatchHistSynchronizer: public HistSynchronizer<GradientSumT> {
   std::vector<::sycl::event> hist_sync_events_;
 };
 
+template <typename GradientSumT>
+class DistributedHistSynchronizer: public HistSynchronizer<GradientSumT> {
+ public:
+  void SyncHistograms(HistUpdater<GradientSumT>* builder,
+                      const std::vector<int>& sync_ids,
+                      RegTree *p_tree) override {
+    builder->builder_monitor_.Start("SyncHistograms");
+    const size_t nbins = builder->hist_builder_.GetNumBins();
+    for (int node = 0; node < builder->nodes_for_explicit_hist_build_.size(); node++) {
+      const auto entry = builder->nodes_for_explicit_hist_build_[node];
+      auto& this_hist = builder->hist_[entry.nid];
+      // Store possible parent node
+      auto& this_local = builder->hist_local_worker_[entry.nid];
+      common::CopyHist(builder->qu_, &this_local, this_hist, nbins);
+
+      if (!(*p_tree)[entry.nid].IsRoot()) {
+        const size_t parent_id = (*p_tree)[entry.nid].Parent();
+        auto sibling_nid = entry.GetSiblingId(p_tree, parent_id);
+        auto& parent_hist = builder->hist_local_worker_[parent_id];
+        auto& sibling_hist = builder->hist_[sibling_nid];
+        common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist,
+                                this_hist, nbins, ::sycl::event());
+        // Store possible parent node
+        auto& sibling_local = builder->hist_local_worker_[sibling_nid];
+        common::CopyHist(builder->qu_, &sibling_local, sibling_hist, nbins);
+      }
+    }
+    builder->ReduceHists(sync_ids, nbins);
+
+    ParallelSubtractionHist(builder, builder->nodes_for_explicit_hist_build_, p_tree);
+    ParallelSubtractionHist(builder, builder->nodes_for_subtraction_trick_, p_tree);
+
+    builder->builder_monitor_.Stop("SyncHistograms");
+  }
+
+  void ParallelSubtractionHist(HistUpdater<GradientSumT>* builder,
+                               const std::vector<ExpandEntry>& nodes,
+                               const RegTree* p_tree) {
+    const size_t nbins = builder->hist_builder_.GetNumBins();
+    for (int node = 0; node < nodes.size(); node++) {
+      const auto entry = nodes[node];
+      if (!((*p_tree)[entry.nid].IsLeftChild())) {
+        auto& this_hist = builder->hist_[entry.nid];
+
+        if (!(*p_tree)[entry.nid].IsRoot()) {
+          const size_t parent_id = (*p_tree)[entry.nid].Parent();
+          auto& parent_hist = builder->hist_[parent_id];
+          auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)];
+          common::SubtractionHist(builder->qu_, &this_hist, parent_hist,
+                                  sibling_hist, nbins, ::sycl::event());
+        }
+      }
+    }
+  }
+};
+
 }  // namespace tree
 }  // namespace sycl
 }  // namespace xgboost
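Both loops rely on the histogram subtraction trick: every sample at a parent lands in exactly one child, so a sibling's histogram equals the parent's minus the built child's (sibling = parent − child), halving the histograms built per tree level. A plain-array sketch of what SubtractionHist computes (illustration, not the commit's SYCL kernel):

#include <cstddef>

// Toy version of the subtraction trick: sibling = parent - this, bin by bin.
// Each bin is assumed to hold two scalars (gradient and hessian).
void SubtractionHistToy(double* sibling, const double* parent,
                        const double* this_hist, std::size_t nbins) {
  for (std::size_t i = 0; i < 2 * nbins; ++i) {
    sibling[i] = parent[i] - this_hist[i];
  }
}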
