Skip to content

Commit 1ab6c0a

Browse files
authored
Fix sampling (#54)
* fix sampling; minor refactoring * fix --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent 6bc5599 commit 1ab6c0a

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed

plugin/sycl/predictor/predictor.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
* Copyright by Contributors 2017-2024
33
*/
44
#include <dmlc/timer.h>
5-
// #pragma GCC diagnostic push
6-
// #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
7-
// #pragma GCC diagnostic ignored "-W#pragma-messages"
8-
// #pragma GCC diagnostic pop
95

106
#include <cstddef>
117
#include <limits>

plugin/sycl/tree/hist_updater.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
336336

337337
template <typename GradientSumT>
338338
void HistUpdater<GradientSumT>::Update(
339-
Context const * ctx,
340339
xgboost::tree::TrainParam const *param,
341340
const common::GHistIndexMatrix &gmat,
342341
linalg::Matrix<GradientPair> *gpair,
@@ -485,6 +484,10 @@ void HistUpdater<GradientSumT>::InitData(
485484
builder_monitor_.Start("InitData");
486485
const auto& info = fmat.Info();
487486

487+
if (!column_sampler_) {
488+
column_sampler_ = xgboost::common::MakeColumnSampler(ctx_);
489+
}
490+
488491
// initialize the row set
489492
{
490493
row_set_collection_.Clear();
@@ -575,9 +578,9 @@ void HistUpdater<GradientSumT>::InitData(
575578

576579
// store a pointer to the tree
577580
p_last_tree_ = &tree;
578-
column_sampler_.Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
579-
param_.colsample_bynode, param_.colsample_bylevel,
580-
param_.colsample_bytree);
581+
column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
582+
param_.colsample_bynode, param_.colsample_bylevel,
583+
param_.colsample_bytree);
581584
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
582585
/* specialized code for dense data:
583586
choose the column that has a least positive number of discrete bins.
@@ -597,8 +600,10 @@ void HistUpdater<GradientSumT>::InitData(
597600
}
598601
CHECK_GT(min_nbins_per_feature, 0U);
599602
}
600-
601-
std::fill(snode_host_.begin(), snode_host_.end(), NodeEntry<GradientSumT>(param_));
603+
{
604+
qu_.wait_and_throw();
605+
std::fill(snode_host_.begin(), snode_host_.end(), NodeEntry<GradientSumT>(param_));
606+
}
602607

603608
{
604609
if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
@@ -628,7 +633,7 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
628633
size_t total_features = 0;
629634
for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
630635
const int32_t nid = nodes_set[nid_in_set].nid;
631-
features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
636+
features_sets[nid_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
632637
for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
633638
const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
634639
if (interaction_constraints_.Query(nid, fid)) {

plugin/sycl/tree/hist_updater.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,19 @@ class HistUpdater {
6262
tree_evaluator_(qu, param, fmat->Info().num_col_),
6363
pruner_(std::move(pruner)),
6464
interaction_constraints_{std::move(int_constraints_)},
65-
p_last_tree_(nullptr), p_last_fmat_(fmat),
66-
column_sampler_(seed_) {
65+
p_last_tree_(nullptr), p_last_fmat_(fmat) {
6766
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
6867
kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
68+
if (param.max_depth > 0) {
69+
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
70+
}
6971
const auto sub_group_sizes =
7072
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
7173
sub_group_size_ = sub_group_sizes.back();
7274
}
7375

7476
// update one tree, growing
75-
void Update(Context const * ctx,
76-
xgboost::tree::TrainParam const *param,
77+
void Update(xgboost::tree::TrainParam const *param,
7778
const common::GHistIndexMatrix &gmat,
7879
linalg::Matrix<GradientPair> *gpair,
7980
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
@@ -181,11 +182,6 @@ class HistUpdater {
181182
RegTree *p_tree,
182183
const USMVector<GradientPair, MemoryType::on_device> &gpair);
183184

184-
void ExpandWithLossGuide(const common::GHistIndexMatrix& gmat,
185-
DMatrix* p_fmat,
186-
RegTree* p_tree,
187-
const USMVector<GradientPair, MemoryType::on_device>& gpair);
188-
189185
void EvaluateAndApplySplits(const common::GHistIndexMatrix &gmat,
190186
RegTree *p_tree,
191187
int *num_leaves,
@@ -200,6 +196,11 @@ class HistUpdater {
200196
std::vector<ExpandEntry>* nodes_for_apply_split,
201197
std::vector<ExpandEntry>* temp_qexpand_depth);
202198

199+
void ExpandWithLossGuide(const common::GHistIndexMatrix& gmat,
200+
DMatrix* p_fmat,
201+
RegTree* p_tree,
202+
const USMVector<GradientPair, MemoryType::on_device>& gpair);
203+
203204
void ReduceHists(const std::vector<int>& sync_ids, size_t nbins);
204205

205206
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
@@ -213,13 +214,14 @@ class HistUpdater {
213214
// --data fields--
214215
const Context* ctx_;
215216
size_t sub_group_size_;
217+
const xgboost::tree::TrainParam& param_;
218+
std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;
216219

217220
// the internal row sets
218221
common::RowSetCollection row_set_collection_;
219222
std::vector<SplitQuery> split_queries_host_;
220223
USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
221224

222-
const xgboost::tree::TrainParam& param_;
223225
TreeEvaluator<GradientSumT> tree_evaluator_;
224226
std::unique_ptr<TreeUpdater> pruner_;
225227
FeatureInteractionConstraintHost interaction_constraints_;
@@ -258,7 +260,6 @@ class HistUpdater {
258260
uint32_t fid_least_bins_;
259261

260262
uint64_t seed_ = 0;
261-
xgboost::common::ColumnSampler column_sampler_;
262263

263264
common::PartitionBuilder partition_builder_;
264265

plugin/sycl/tree/updater_quantile_hist.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void QuantileHistMaker::CallUpdate(
8080
qu_.wait();
8181

8282
for (auto tree : trees) {
83-
pimpl->Update(ctx_, param, gmat_, gpair, gpair_device_, dmat, out_position, tree);
83+
pimpl->Update(param, gmat_, gpair, gpair_device_, dmat, out_position, tree);
8484
}
8585
}
8686

0 commit comments

Comments
 (0)