Commit 24d225c

[SYCL] Implement UpdatePredictionCache and connect updater with learner. (dmlc#10701)

Co-authored-by: Dmitry Razdoburdin <>

1 parent 9b88495 commit 24d225c

11 files changed: +502 -126 lines changed

plugin/sycl/tree/hist_updater.cc

Lines changed: 95 additions & 0 deletions
@@ -307,6 +307,99 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
   builder_monitor_.Stop("ExpandWithLossGuide");
 }
 
+template <typename GradientSumT>
+void HistUpdater<GradientSumT>::Update(
+    xgboost::tree::TrainParam const *param,
+    const common::GHistIndexMatrix &gmat,
+    const USMVector<GradientPair, MemoryType::on_device>& gpair,
+    DMatrix *p_fmat,
+    xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+    RegTree *p_tree) {
+  builder_monitor_.Start("Update");
+
+  tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_);
+  interaction_constraints_.Reset();
+
+  this->InitData(gmat, gpair, *p_fmat, *p_tree);
+  if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
+    ExpandWithLossGuide(gmat, p_tree, gpair);
+  } else {
+    ExpandWithDepthWise(gmat, p_tree, gpair);
+  }
+
+  for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
+    p_tree->Stat(nid).loss_chg = snode_host_[nid].best.loss_chg;
+    p_tree->Stat(nid).base_weight = snode_host_[nid].weight;
+    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_host_[nid].stats.GetHess());
+  }
+
+  builder_monitor_.Stop("Update");
+}
+
+template<typename GradientSumT>
+bool HistUpdater<GradientSumT>::UpdatePredictionCache(
+    const DMatrix* data,
+    linalg::MatrixView<float> out_preds) {
+  // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
+  // conjunction with Update().
+  if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
+    return false;
+  }
+  builder_monitor_.Start("UpdatePredictionCache");
+  CHECK_GT(out_preds.Size(), 0U);
+
+  const size_t stride = out_preds.Stride(0);
+  const bool is_first_group = (out_pred_ptr == nullptr);
+  const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
+  const bool is_last_group = (gid + 1 == stride);
+
+  const int buffer_size = out_preds.Size() * stride;
+  if (buffer_size == 0) return true;
+
+  ::sycl::event event;
+  if (is_first_group) {
+    out_preds_buf_.ResizeNoCopy(&qu_, buffer_size);
+    out_pred_ptr = &out_preds(0);
+    event = qu_.memcpy(out_preds_buf_.Data(), out_pred_ptr, buffer_size * sizeof(bst_float), event);
+  }
+  auto* out_preds_buf_ptr = out_preds_buf_.Data();
+
+  size_t n_nodes = row_set_collection_.Size();
+  std::vector<::sycl::event> events(n_nodes);
+  for (size_t node = 0; node < n_nodes; node++) {
+    const common::RowSetCollection::Elem& rowset = row_set_collection_[node];
+    if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) {
+      int nid = rowset.node_id;
+      // if a node is marked as deleted by the pruner, traverse upward to locate
+      // a non-deleted leaf.
+      if ((*p_last_tree_)[nid].IsDeleted()) {
+        while ((*p_last_tree_)[nid].IsDeleted()) {
+          nid = (*p_last_tree_)[nid].Parent();
+        }
+        CHECK((*p_last_tree_)[nid].IsLeaf());
+      }
+      bst_float leaf_value = (*p_last_tree_)[nid].LeafValue();
+      const size_t* rid = rowset.begin;
+      const size_t num_rows = rowset.Size();
+
+      events[node] = qu_.submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(event);
+        cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
+          out_preds_buf_ptr[rid[pid.get_id(0)] * stride + gid] += leaf_value;
+        });
+      });
+    }
+  }
+  if (is_last_group) {
+    qu_.memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
+    out_pred_ptr = nullptr;
+  }
+  qu_.wait();
+
+  builder_monitor_.Stop("UpdatePredictionCache");
+  return true;
+}
+
 template<typename GradientSumT>
 void HistUpdater<GradientSumT>::InitSampling(
     const USMVector<GradientPair, MemoryType::on_device> &gpair,

@@ -479,6 +572,8 @@ void HistUpdater<GradientSumT>::InitData(
     }
   }
 
+  // store a pointer to the tree
+  p_last_tree_ = &tree;
   column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
                         param_.colsample_bynode, param_.colsample_bylevel,
                         param_.colsample_bytree);

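Note on the indexing in UpdatePredictionCache(): the method is called once per output
group (e.g. once per class in multiclass training), and each call receives a MatrixView
whose element (0) sits at offset gid inside one row-major buffer of shape
[n_rows x stride]. The first call stages the whole buffer on the device; every call then
scatters its group's leaf values at row * stride + gid. A minimal host-side sketch of
that accumulation, without SYCL (function and variable names here are illustrative, not
part of the patch):

    #include <cstddef>
    #include <vector>

    // preds: row-major [n_rows x n_groups]; n_groups plays the role of out_preds.Stride(0).
    void AccumulateLeaf(std::vector<float>* preds, std::size_t n_groups, std::size_t gid,
                        const std::vector<std::size_t>& rows_in_leaf, float leaf_value) {
      for (std::size_t row : rows_in_leaf) {
        // Same index arithmetic as the parallel_for above: one slot per (row, group).
        (*preds)[row * n_groups + gid] += leaf_value;
      }
    }

The SYCL version distributes this loop across work-items and chains the per-node kernels
on the initial memcpy event, so all leaves of one group update the staging buffer
concurrently.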
plugin/sycl/tree/hist_updater.h

Lines changed: 15 additions & 4 deletions
@@ -11,10 +11,10 @@
 #include <xgboost/tree_updater.h>
 #pragma GCC diagnostic pop
 
-#include <utility>
 #include <vector>
 #include <memory>
 #include <queue>
+#include <utility>
 
 #include "../common/partition_builder.h"
 #include "split_evaluator.h"

@@ -54,12 +54,10 @@ class HistUpdater {
   explicit HistUpdater(const Context* ctx,
                        ::sycl::queue qu,
                        const xgboost::tree::TrainParam& param,
-                       std::unique_ptr<TreeUpdater> pruner,
                        FeatureInteractionConstraintHost int_constraints_,
                        DMatrix const* fmat)
     : ctx_(ctx), qu_(qu), param_(param),
       tree_evaluator_(qu, param, fmat->Info().num_col_),
-      pruner_(std::move(pruner)),
       interaction_constraints_{std::move(int_constraints_)},
       p_last_tree_(nullptr), p_last_fmat_(fmat) {
     builder_monitor_.Init("SYCL::Quantile::HistUpdater");

@@ -73,6 +71,17 @@ class HistUpdater {
     sub_group_size_ = sub_group_sizes.back();
   }
 
+  // update one tree, growing
+  void Update(xgboost::tree::TrainParam const *param,
+              const common::GHistIndexMatrix &gmat,
+              const USMVector<GradientPair, MemoryType::on_device>& gpair,
+              DMatrix *p_fmat,
+              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+              RegTree *p_tree);
+
+  bool UpdatePredictionCache(const DMatrix* data,
+                             linalg::MatrixView<float> p_out_preds);
+
   void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
   void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

@@ -200,7 +209,6 @@ class HistUpdater {
   std::vector<SplitEntry<GradientSumT>> best_splits_host_;
 
   TreeEvaluator<GradientSumT> tree_evaluator_;
-  std::unique_ptr<TreeUpdater> pruner_;
   FeatureInteractionConstraintHost interaction_constraints_;
 
   // back pointers to tree and data matrix

@@ -247,6 +255,9 @@ class HistUpdater {
   std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
   std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
 
+  USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
+  bst_float* out_pred_ptr = nullptr;
+
   ::sycl::queue qu_;
 };

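The two new members back the prediction-cache path added in hist_updater.cc:
out_preds_buf_ stages a device copy of the full prediction matrix, while out_pred_ptr
doubles as a "first group seen" flag and as the base address from which later calls
recover their group id. A standalone sketch of that bookkeeping (names here are
hypothetical; the real state lives in HistUpdater):

    #include <cstddef>

    struct GroupTracker {
      float* base = nullptr;  // mirrors out_pred_ptr: null until the first group arrives

      // group_origin is &out_preds(0) for the current call; stride is the group count.
      std::size_t Observe(float* group_origin, std::size_t stride) {
        if (base == nullptr) base = group_origin;  // first group: remember the origin
        std::size_t gid = group_origin - base;     // later groups: pointer offset = id
        if (gid + 1 == stride) base = nullptr;     // last group: reset for the next round
        return gid;
      }
    };

Resetting on the last group is what lets the cache be reused across boosting rounds
without an explicit re-initialization call.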
plugin/sycl/tree/updater_quantile_hist.cc

Lines changed: 90 additions & 2 deletions
@@ -3,6 +3,7 @@
  * \file updater_quantile_hist.cc
  */
 #include <vector>
+#include <memory>
 
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wtautological-constant-compare"

@@ -29,19 +30,106 @@ void QuantileHistMaker::Configure(const Args& args) {
 
   param_.UpdateAllowUnknown(args);
   hist_maker_param_.UpdateAllowUnknown(args);
+
+  bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
+  if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
+    if (!hist_maker_param_.single_precision_histogram) {
+      LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
+    }
+    hist_precision_ = HistPrecision::fp32;
+  } else {
+    hist_precision_ = HistPrecision::fp64;
+  }
+}
+
+template<typename GradientSumT>
+void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pimpl,
+                                 DMatrix *dmat) {
+  pimpl->reset(new HistUpdater<GradientSumT>(
+                   ctx_,
+                   qu_,
+                   param_,
+                   int_constraint_, dmat));
+  if (collective::IsDistributed()) {
+    LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
+  } else {
+    (*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+    (*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  }
+}
+
+template<typename GradientSumT>
+void QuantileHistMaker::CallUpdate(
+      const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
+      xgboost::tree::TrainParam const *param,
+      linalg::Matrix<GradientPair> *gpair,
+      DMatrix *dmat,
+      xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+      const std::vector<RegTree *> &trees) {
+  const auto* gpair_h = gpair->Data();
+  gpair_device_.Resize(&qu_, gpair_h->Size());
+  qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair));
+  qu_.wait();
+
+  for (auto tree : trees) {
+    pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree);
+  }
 }
 
 void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
                                linalg::Matrix<GradientPair>* gpair,
                                DMatrix *dmat,
                                xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                                const std::vector<RegTree *> &trees) {
-  LOG(FATAL) << "Not Implemented yet";
+  if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
+    updater_monitor_.Start("DeviceMatrixInitialization");
+    sycl::DeviceMatrix dmat_device;
+    dmat_device.Init(qu_, dmat);
+    updater_monitor_.Stop("DeviceMatrixInitialization");
+    updater_monitor_.Start("GmatInitialization");
+    gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
+    updater_monitor_.Stop("GmatInitialization");
+    is_gmat_initialized_ = true;
+  }
+  // rescale learning rate according to size of trees
+  float lr = param_.learning_rate;
+  param_.learning_rate = lr / trees.size();
+  int_constraint_.Configure(param_, dmat->Info().num_col_);
+  // build tree
+  if (hist_precision_ == HistPrecision::fp32) {
+    if (!pimpl_fp32) {
+      SetPimpl(&pimpl_fp32, dmat);
+    }
+    CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
+  } else {
+    if (!pimpl_fp64) {
+      SetPimpl(&pimpl_fp64, dmat);
+    }
+    CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
+  }
+
+  param_.learning_rate = lr;
+
+  p_last_dmat_ = dmat;
 }
 
 bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                               linalg::MatrixView<float> out_preds) {
-  LOG(FATAL) << "Not Implemented yet";
+  if (param_.subsample < 1.0f) return false;
+
+  if (hist_precision_ == HistPrecision::fp32) {
+    if (pimpl_fp32) {
+      return pimpl_fp32->UpdatePredictionCache(data, out_preds);
+    } else {
+      return false;
+    }
+  } else {
+    if (pimpl_fp64) {
+      return pimpl_fp64->UpdatePredictionCache(data, out_preds);
+    } else {
+      return false;
+    }
+  }
 }
 
 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")

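Configure() now fixes the histogram precision up front: fp64 accumulation is used only
when the device reports ::sycl::aspect::fp64 and single_precision_histogram was not
requested; otherwise it degrades to fp32 with a warning. A self-contained sketch of
that decision (the free function and its parameters are illustrative):

    #include <iostream>

    enum class HistPrecision { fp32, fp64 };

    HistPrecision ChooseHistPrecision(bool single_precision_requested, bool device_has_fp64) {
      if (single_precision_requested || !device_has_fp64) {
        if (!single_precision_requested) {
          // fp64 was wanted but the device cannot provide it: warn and degrade.
          std::cerr << "Target device doesn't support fp64, "
                       "using single_precision_histogram=True\n";
        }
        return HistPrecision::fp32;
      }
      return HistPrecision::fp64;
    }

Note also that Update() divides the learning rate by trees.size() before growing and
restores it afterwards, so each tree of a multi-tree round contributes a proportional
share of the boosting step.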
plugin/sycl/tree/updater_quantile_hist.h

Lines changed: 30 additions & 1 deletion
@@ -9,15 +9,17 @@
 #include <xgboost/tree_updater.h>
 
 #include <vector>
+#include <memory>
 
 #include "../data/gradient_index.h"
 #include "../common/hist_util.h"
 #include "../common/row_set.h"
 #include "../common/partition_builder.h"
 #include "split_evaluator.h"
 #include "../device_manager.h"
-
+#include "hist_updater.h"
 #include "xgboost/data.h"
+
 #include "xgboost/json.h"
 #include "../../src/tree/constraints.h"
 #include "../../src/common/random.h"

@@ -75,12 +77,39 @@ class QuantileHistMaker: public TreeUpdater {
   HistMakerTrainParam hist_maker_param_;
   // training parameter
   xgboost::tree::TrainParam param_;
+  // quantized data matrix
+  common::GHistIndexMatrix gmat_;
+  // (optional) data matrix with feature grouping
+  // column accessor
+  DMatrix const* p_last_dmat_ {nullptr};
+  bool is_gmat_initialized_ {false};
 
   xgboost::common::Monitor updater_monitor_;
 
+  template<typename GradientSumT>
+  void SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>*, DMatrix *dmat);
+
+  template<typename GradientSumT>
+  void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
+                  xgboost::tree::TrainParam const *param,
+                  linalg::Matrix<GradientPair> *gpair,
+                  DMatrix *dmat,
+                  xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+                  const std::vector<RegTree *> &trees);
+
+  enum class HistPrecision {fp32, fp64};
+  HistPrecision hist_precision_;
+
+  std::unique_ptr<HistUpdater<float>> pimpl_fp32;
+  std::unique_ptr<HistUpdater<double>> pimpl_fp64;
+
+  FeatureInteractionConstraintHost int_constraint_;
+
   ::sycl::queue qu_;
   DeviceManager device_manager;
   ObjInfo const *task_{nullptr};
+
+  USMVector<GradientPair, MemoryType::on_device> gpair_device_;
 };

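p_last_dmat_ and is_gmat_initialized_ exist so that the quantized index gmat_ is
rebuilt only when training moves to a different DMatrix. A minimal standalone sketch
of the same caching policy (all type names below are stand-ins, not the real xgboost
types):

    #include <cstdint>

    struct Data {};              // stands in for DMatrix
    struct QuantizedIndex {      // stands in for common::GHistIndexMatrix
      void Init(const Data&, std::uint32_t /*max_bin*/) { /* bin the features */ }
    };

    class GmatCache {
     public:
      QuantizedIndex& Get(const Data* dmat, std::uint32_t max_bin) {
        if (dmat != last_ || !initialized_) {  // first call or new data: rebuild
          index_.Init(*dmat, max_bin);
          initialized_ = true;
          last_ = dmat;
        }
        return index_;
      }

     private:
      QuantizedIndex index_;
      const Data* last_ {nullptr};
      bool initialized_ {false};
    };

(The patch updates p_last_dmat_ at the end of Update() rather than inside the check,
but the effect is the same: one quantization pass per distinct training matrix.)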
src/gbm/gbtree.cc

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
     case TreeMethod::kAuto:  // Use hist as default in 2.0
     case TreeMethod::kHist: {
       return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
-                                 [] { return "grow_gpu_hist"; });
+                                 [] { return "grow_gpu_hist"; },
+                                 [] { return "grow_quantile_histmaker_sycl"; });
     }
     case TreeMethod::kApprox: {
       return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });

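With this change the hist tree method resolves three ways - CPU, CUDA, or SYCL - each
branch returning the name of a registered tree updater. A simplified standalone model
of that dispatch (DeviceKind and the free function are hypothetical; the real
dispatcher is a member of xgboost::Context):

    #include <string>

    enum class DeviceKind { kCPU, kCUDA, kSYCL };

    template <typename CpuFn, typename CudaFn, typename SyclFn>
    std::string DispatchDevice(DeviceKind device, CpuFn cpu, CudaFn cuda, SyclFn sycl) {
      switch (device) {
        case DeviceKind::kCUDA: return cuda();
        case DeviceKind::kSYCL: return sycl();
        default:                return cpu();
      }
    }

    // Mirrors the kHist branch above:
    //   DispatchDevice(kind,
    //                  [] { return "grow_quantile_histmaker"; },
    //                  [] { return "grow_gpu_hist"; },
    //                  [] { return "grow_quantile_histmaker_sycl"; });

The returned string is the registry key under which XGBOOST_REGISTER_TREE_UPDATER
exposed the SYCL updater ("grow_quantile_histmaker_sycl").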