
Commit c7e7ce7

[SYCL] Add nodes initialisation (dmlc#10269)

---------

Co-authored-by: Dmitry Razdoburdin <>
Co-authored-by: Jiaming Yuan <[email protected]>

1 parent 7a54ca4 commit c7e7ce7

File tree: 7 files changed (+342 additions, -33 deletions)


plugin/sycl/device_manager.cc

Lines changed: 0 additions & 5 deletions
@@ -2,11 +2,6 @@
  * Copyright 2017-2023 by Contributors
  * \file device_manager.cc
  */
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
-#pragma GCC diagnostic ignored "-W#pragma-messages"
-#pragma GCC diagnostic pop
-
 #include "../sycl/device_manager.h"
 
 #include "../../src/collective/communicator-inl.h"

plugin/sycl/device_manager.h

Lines changed: 4 additions & 0 deletions
@@ -12,7 +12,11 @@
 
 #include <CL/sycl.hpp>
 
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
 #include "xgboost/context.h"
+#pragma GCC diagnostic pop
 
 namespace xgboost {
 namespace sycl {

plugin/sycl/objective/multiclass_obj.cc

Lines changed: 2 additions & 6 deletions
@@ -3,19 +3,15 @@
  * \file multiclass_obj.cc
  * \brief Definition of multi-class classification objectives.
  */
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
-#pragma GCC diagnostic ignored "-W#pragma-messages"
-#pragma GCC diagnostic pop
-
 #include <vector>
 #include <algorithm>
 #include <limits>
 #include <utility>
 
-#include "xgboost/parameter.h"
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
+#include "xgboost/parameter.h"
 #include "xgboost/data.h"
 #include "../../src/common/math.h"
 #pragma GCC diagnostic pop
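
Note: the three changes above apply the same header-hygiene pattern. An isolated push/ignore/pop sequence suppresses nothing, since the pop cancels the push before any header is parsed; the suppression has to bracket the #include that actually emits the warnings. A minimal sketch of the pattern, using a hypothetical noisy_header.h (not a file in this commit):

    #pragma GCC diagnostic push
    #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
    #pragma GCC diagnostic ignored "-W#pragma-messages"
    #include "noisy_header.h"  // hypothetical header that triggers these warnings
    #pragma GCC diagnostic pop
    // from this point on, both warnings are reported normally again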

plugin/sycl/tree/hist_updater.cc

Lines changed: 96 additions & 1 deletion
@@ -8,6 +8,7 @@
 #include <oneapi/dpl/random>
 
 #include "../common/hist_util.h"
+#include "../../src/collective/allreduce.h"
 
 namespace xgboost {
 namespace sycl {
@@ -111,7 +112,6 @@ void HistUpdater<GradientSumT>::InitSampling(
 
 template<typename GradientSumT>
 void HistUpdater<GradientSumT>::InitData(
-    Context const * ctx,
     const common::GHistIndexMatrix& gmat,
     const USMVector<GradientPair, MemoryType::on_device> &gpair,
     const DMatrix& fmat,
@@ -215,6 +215,101 @@ void HistUpdater<GradientSumT>::InitData(
       data_layout_ = kSparseData;
     }
   }
+
+  if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
+    /* specialized code for dense data:
+       choose the column that has a least positive number of discrete bins.
+       For dense data (with no missing value),
+       the sum of gradient histogram is equal to snode[nid] */
+    const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
+    const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
+    uint32_t min_nbins_per_feature = 0;
+    for (bst_uint i = 0; i < nfeature; ++i) {
+      const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
+      if (nbins > 0) {
+        if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
+          min_nbins_per_feature = nbins;
+          fid_least_bins_ = i;
+        }
+      }
+    }
+    CHECK_GT(min_nbins_per_feature, 0U);
+  }
+
+  std::fill(snode_host_.begin(), snode_host_.end(), NodeEntry<GradientSumT>(param_));
+  builder_monitor_.Stop("InitData");
+}
+
+template <typename GradientSumT>
+void HistUpdater<GradientSumT>::InitNewNode(int nid,
+                                            const common::GHistIndexMatrix& gmat,
+                                            const USMVector<GradientPair,
+                                                            MemoryType::on_device> &gpair,
+                                            const DMatrix& fmat,
+                                            const RegTree& tree) {
+  builder_monitor_.Start("InitNewNode");
+
+  snode_host_.resize(tree.NumNodes(), NodeEntry<GradientSumT>(param_));
+  {
+    if (tree[nid].IsRoot()) {
+      GradStats<GradientSumT> grad_stat;
+      if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
+        const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
+        const uint32_t ibegin = row_ptr[fid_least_bins_];
+        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+        const auto* hist = reinterpret_cast<GradStats<GradientSumT>*>(hist_[nid].Data());
+
+        std::vector<GradStats<GradientSumT>> ets(iend - ibegin);
+        qu_.memcpy(ets.data(), hist + ibegin,
+                   (iend - ibegin) * sizeof(GradStats<GradientSumT>)).wait_and_throw();
+        for (const auto& et : ets) {
+          grad_stat += et;
+        }
+      } else {
+        const common::RowSetCollection::Elem e = row_set_collection_[nid];
+        const size_t* row_idxs = e.begin;
+        const size_t size = e.Size();
+        const GradientPair* gpair_ptr = gpair.DataConst();
+
+        ::sycl::buffer<GradStats<GradientSumT>> buff(&grad_stat, 1);
+        qu_.submit([&](::sycl::handler& cgh) {
+          auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
+          cgh.parallel_for<>(::sycl::range<1>(size), reduction,
+                             [=](::sycl::item<1> pid, auto& sum) {
+            size_t i = pid.get_id(0);
+            size_t row_idx = row_idxs[i];
+            if constexpr (std::is_same<GradientPair::ValueT, GradientSumT>::value) {
+              sum += gpair_ptr[row_idx];
+            } else {
+              sum += GradStats<GradientSumT>(gpair_ptr[row_idx].GetGrad(),
+                                             gpair_ptr[row_idx].GetHess());
+            }
+          });
+        }).wait_and_throw();
+      }
+      auto rc = collective::Allreduce(
+          ctx_, linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
+          collective::Op::kSum);
+      SafeColl(rc);
+      snode_host_[nid].stats = grad_stat;
+    } else {
+      int parent_id = tree[nid].Parent();
+      if (tree[nid].IsLeftChild()) {
+        snode_host_[nid].stats = snode_host_[parent_id].best.left_sum;
+      } else {
+        snode_host_[nid].stats = snode_host_[parent_id].best.right_sum;
+      }
+    }
+  }
+
+  // calculating the weights
+  {
+    auto evaluator = tree_evaluator_.GetEvaluator();
+    bst_uint parentid = tree[nid].Parent();
+    snode_host_[nid].weight = evaluator.CalcWeight(parentid, snode_host_[nid].stats);
+    snode_host_[nid].root_gain = evaluator.CalcGain(parentid, snode_host_[nid].stats);
+  }
+  builder_monitor_.Stop("InitNewNode");
 }
 
 template class HistUpdater<float>;
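
Note: the sparse branch of InitNewNode above sums per-row gradients with a SYCL 2020 buffer-based reduction. A minimal, self-contained sketch of that pattern, under the SYCL 2020 header; the names grad, q, and the float payload are illustrative, not part of the commit:

    #include <sycl/sycl.hpp>
    #include <iostream>
    #include <vector>

    int main() {
      sycl::queue q;
      std::vector<float> grad = {0.5f, -1.0f, 2.0f, 0.25f};
      float sum = 0.0f;
      {
        // One-element buffer holding the reduction target, as in InitNewNode.
        sycl::buffer<float> sum_buf(&sum, 1);
        sycl::buffer<float> grad_buf(grad.data(), grad.size());
        q.submit([&](sycl::handler& cgh) {
          sycl::accessor acc{grad_buf, cgh, sycl::read_only};
          auto reduction = sycl::reduction(sum_buf, cgh, sycl::plus<>());
          cgh.parallel_for(sycl::range<1>(grad.size()), reduction,
                           [=](sycl::item<1> it, auto& s) { s += acc[it]; });
        }).wait_and_throw();
      }  // sum_buf goes out of scope here and copies the result back into `sum`
      std::cout << sum << "\n";  // prints 1.75
    }

Without the initialize_to_identity property, the reduction combines with the buffer's existing value (0.0f here), which is why the commit's code can reduce directly into grad_stat.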

plugin/sycl/tree/hist_updater.h

Lines changed: 38 additions & 8 deletions
@@ -26,19 +26,36 @@ namespace xgboost {
 namespace sycl {
 namespace tree {
 
+// data structure
+template<typename GradType>
+struct NodeEntry {
+  /*! \brief statics for node entry */
+  GradStats<GradType> stats;
+  /*! \brief loss of this node, without split */
+  GradType root_gain;
+  /*! \brief weight calculated related to current data */
+  GradType weight;
+  /*! \brief current best solution */
+  SplitEntry<GradType> best;
+  // constructor
+  explicit NodeEntry(const xgboost::tree::TrainParam& param)
+      : root_gain(0.0f), weight(0.0f) {}
+};
+
 template<typename GradientSumT>
 class HistUpdater {
  public:
   template <MemoryType memory_type = MemoryType::shared>
   using GHistRowT = common::GHistRow<GradientSumT, memory_type>;
   using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
 
-  explicit HistUpdater(::sycl::queue qu,
-                       const xgboost::tree::TrainParam& param,
-                       std::unique_ptr<TreeUpdater> pruner,
-                       FeatureInteractionConstraintHost int_constraints_,
-                       DMatrix const* fmat)
-    : qu_(qu), param_(param),
+  explicit HistUpdater(const Context* ctx,
+                       ::sycl::queue qu,
+                       const xgboost::tree::TrainParam& param,
+                       std::unique_ptr<TreeUpdater> pruner,
+                       FeatureInteractionConstraintHost int_constraints_,
+                       DMatrix const* fmat)
+    : ctx_(ctx), qu_(qu), param_(param),
       tree_evaluator_(qu, param, fmat->Info().num_col_),
       pruner_(std::move(pruner)),
       interaction_constraints_{std::move(int_constraints_)},
@@ -61,8 +78,7 @@ class HistUpdater {
                     USMVector<size_t, MemoryType::on_device>* row_indices);
 
 
-  void InitData(Context const * ctx,
-                const common::GHistIndexMatrix& gmat,
+  void InitData(const common::GHistIndexMatrix& gmat,
                 const USMVector<GradientPair, MemoryType::on_device> &gpair,
                 const DMatrix& fmat,
                 const RegTree& tree);
@@ -78,6 +94,12 @@
                       data_layout_ != kSparseData, hist_buffer, event_priv);
   }
 
+  void InitNewNode(int nid,
+                   const common::GHistIndexMatrix& gmat,
+                   const USMVector<GradientPair, MemoryType::on_device> &gpair,
+                   const DMatrix& fmat,
+                   const RegTree& tree);
+
   void BuildLocalHistograms(const common::GHistIndexMatrix &gmat,
                             RegTree *p_tree,
                             const USMVector<GradientPair, MemoryType::on_device> &gpair);
@@ -89,6 +111,7 @@
                             const USMVector<GradientPair, MemoryType::on_device> &gpair);
 
   // --data fields--
+  const Context* ctx_;
   size_t sub_group_size_;
 
   // the internal row sets
@@ -113,9 +136,16 @@
   /*! \brief culmulative histogram of gradients. */
   common::HistCollection<GradientSumT, MemoryType::on_device> hist_;
 
+  /*! \brief TreeNode Data: statistics for each constructed node */
+  std::vector<NodeEntry<GradientSumT>> snode_host_;
+
   xgboost::common::Monitor builder_monitor_;
   xgboost::common::Monitor kernel_monitor_;
 
+  /*! \brief feature with least # of bins. to be used for dense specialization
+             of InitNewNode() */
+  uint32_t fid_least_bins_;
+
   uint64_t seed_ = 0;
 
   // key is the node id which should be calculated by Subtraction Trick, value is the node which
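
Note: a rough, hypothetical mock of the snode_host_ bookkeeping added above, with simplified stand-ins for GradStats/SplitEntry rather than the real types. It shows the inheritance rule InitNewNode implements in hist_updater.cc: a non-root node's statistics come from its parent's best split sums.

    #include <vector>

    struct Stats { float grad = 0.f, hess = 0.f; };    // stand-in for GradStats
    struct Split { Stats left_sum, right_sum; };       // stand-in for SplitEntry
    struct Node  { Stats stats; float root_gain = 0.f, weight = 0.f; Split best; };

    int main() {
      std::vector<Node> snode(3);                  // root (0) plus two children
      snode[0].best.left_sum  = {1.0f, 2.0f};      // filled in by split evaluation
      snode[0].best.right_sum = {0.5f, 1.5f};
      snode[1].stats = snode[0].best.left_sum;     // left child inherits left_sum
      snode[2].stats = snode[0].best.right_sum;    // right child inherits right_sum
    }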

plugin/sycl/tree/param.h

Lines changed: 109 additions & 0 deletions
@@ -49,6 +49,115 @@ struct TrainParam {
 template <typename GradType>
 using GradStats = xgboost::detail::GradientPairInternal<GradType>;
 
+/*!
+ * \brief SYCL implementation of SplitEntryContainer for device compilation.
+ * Original structure cannot be used due 'cat_bits' field of type std::vector<uint32_t>,
+ * which is not device-copyable
+ */
+template<typename GradientT>
+struct SplitEntryContainer {
+  /*! \brief loss change after split this node */
+  bst_float loss_chg {0.0f};
+  /*! \brief split index */
+  bst_feature_t sindex{0};
+  bst_float split_value{0.0f};
+
+
+  GradientT left_sum;
+  GradientT right_sum;
+
+
+  SplitEntryContainer() = default;
+
+
+  friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
+    os << "loss_chg: " << s.loss_chg << ", "
+       << "split index: " << s.SplitIndex() << ", "
+       << "split value: " << s.split_value << ", "
+       << "left_sum: " << s.left_sum << ", "
+       << "right_sum: " << s.right_sum;
+    return os;
+  }
+  /*!\return feature index to split on */
+  bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
+  /*!\return whether missing value goes to left branch */
+  bool DefaultLeft() const { return (sindex >> 31) != 0; }
+  /*!
+   * \brief decides whether we can replace current entry with the given statistics
+   *
+   * This function gives better priority to lower index when loss_chg == new_loss_chg.
+   * Not the best way, but helps to give consistent result during multi-thread
+   * execution.
+   *
+   * \param new_loss_chg the loss reduction get through the split
+   * \param split_index the feature index where the split is on
+   */
+  inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
+    if (::sycl::isinf(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
+                                        // for example when lambda = 0 & min_child_weight = 0
+                                        // skip value in this case
+      return false;
+    } else if (this->SplitIndex() <= split_index) {
+      return new_loss_chg > this->loss_chg;
+    } else {
+      return !(this->loss_chg > new_loss_chg);
+    }
+  }
+  /*!
+   * \brief update the split entry, replace it if e is better
+   * \param e candidate split solution
+   * \return whether the proposed split is better and can replace current split
+   */
+  inline bool Update(const SplitEntryContainer &e) {
+    if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
+      this->loss_chg = e.loss_chg;
+      this->sindex = e.sindex;
+      this->split_value = e.split_value;
+      this->left_sum = e.left_sum;
+      this->right_sum = e.right_sum;
+      return true;
+    } else {
+      return false;
+    }
+  }
+  /*!
+   * \brief update the split entry, replace it if e is better
+   * \param new_loss_chg loss reduction of new candidate
+   * \param split_index feature index to split on
+   * \param new_split_value the split point
+   * \param default_left whether the missing value goes to left
+   * \return whether the proposed split is better and can replace current split
+   */
+  bool Update(bst_float new_loss_chg, unsigned split_index,
+              bst_float new_split_value, bool default_left,
+              const GradientT &left_sum,
+              const GradientT &right_sum) {
+    if (this->NeedReplace(new_loss_chg, split_index)) {
+      this->loss_chg = new_loss_chg;
+      if (default_left) {
+        split_index |= (1U << 31);
+      }
+      this->sindex = split_index;
+      this->split_value = new_split_value;
+      this->left_sum = left_sum;
+      this->right_sum = right_sum;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+
+  /*! \brief same as update, used by AllReduce*/
+  inline static void Reduce(SplitEntryContainer &dst,         // NOLINT(*)
+                            const SplitEntryContainer &src) {  // NOLINT(*)
+    dst.Update(src);
+  }
+};
+
+template<typename GradType>
+using SplitEntry = SplitEntryContainer<GradStats<GradType>>;
+
 } // namespace tree
 } // namespace sycl
 } // namespace xgboost
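
Note: SplitIndex() and DefaultLeft() above decode a packed field: the low 31 bits of sindex hold the feature index, and bit 31 holds the default-left flag set by Update(). A small standalone demonstration of that packing; the main() and variable names are illustrative only:

    #include <cassert>
    #include <cstdint>

    int main() {
      unsigned split_index = 42;
      bool default_left = true;
      if (default_left) {
        split_index |= (1U << 31);             // as in Update()
      }
      uint32_t sindex = split_index;

      unsigned feature = sindex & ((1U << 31) - 1U);  // SplitIndex()
      bool goes_left = (sindex >> 31) != 0;           // DefaultLeft()
      assert(feature == 42 && goes_left);
    }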
