Skip to content

Commit 916d4a4

Browse files
authored
Add tests for EvaluateSplits (#59)
* Minor refactoring; * optimize host-device memory sync; * add test for EvaluateSplits; * linting. --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent ba88551 commit 916d4a4

File tree

4 files changed

+124
-39
lines changed

4 files changed

+124
-39
lines changed

plugin/sycl/tree/hist_updater.cc

Lines changed: 26 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
183183
int *num_leaves,
184184
int depth,
185185
std::vector<ExpandEntry> *temp_qexpand_depth) {
186-
EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);
186+
EvaluateSplits(qexpand_depth_wise_, gmat, *p_tree);
187187

188188
std::vector<ExpandEntry> nodes_for_apply_split;
189189
AddSplitsToTree(gmat, p_tree, num_leaves, depth,
@@ -280,7 +280,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
280280

281281
this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree);
282282

283-
this->EvaluateSplits({node}, gmat, hist_, *p_tree);
283+
this->EvaluateSplits({node}, gmat, *p_tree);
284284
node.split.loss_chg = snode_host_[ExpandEntry::kRootNid].best.loss_chg;
285285

286286
qexpand_loss_guided_->push(node);
@@ -325,7 +325,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
325325
snode_host_[cleft].weight, snode_host_[cright].weight);
326326
interaction_constraints_.Split(nid, featureid, cleft, cright);
327327

328-
this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
328+
this->EvaluateSplits({left_node, right_node}, gmat, *p_tree);
329329
left_node.split.loss_chg = snode_host_[cleft].best.loss_chg;
330330
right_node.split.loss_chg = snode_host_[cright].best.loss_chg;
331331

@@ -472,7 +472,7 @@ void HistUpdater<GradientSumT>::InitSampling(
472472
});
473473
});
474474
} else {
475-
// Use oneDPL uniform for better perf, as far as bernoulli_distribution uses fp64
475+
// Use oneDPL uniform, as far as bernoulli_distribution uses fp64
476476
event = qu_.submit([&](::sycl::handler& cgh) {
477477
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
478478
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
@@ -649,45 +649,32 @@ template<typename GradientSumT>
649649
void HistUpdater<GradientSumT>::EvaluateSplits(
650650
const std::vector<ExpandEntry>& nodes_set,
651651
const common::GHistIndexMatrix& gmat,
652-
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
653652
const RegTree& tree) {
654653
builder_monitor_.Start("EvaluateSplits");
655654

656655
const size_t n_nodes_in_set = nodes_set.size();
657656

658657
using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
659-
std::vector<FeatureSetType> features_sets(n_nodes_in_set);
660658

661659
// Generate feature set for each tree node
662-
size_t total_features = 0;
663-
for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
664-
const int32_t nid = nodes_set[nid_in_set].nid;
665-
features_sets[nid_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
666-
for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
667-
const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
668-
if (interaction_constraints_.Query(nid, fid)) {
669-
total_features++;
670-
}
671-
}
672-
}
673-
674-
split_queries_host_.resize(total_features);
675660
size_t pos = 0;
676-
677661
for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
678-
const size_t nid = nodes_set[nid_in_set].nid;
679-
680-
for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
681-
const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
662+
const bst_node_t nid = nodes_set[nid_in_set].nid;
663+
FeatureSetType features_set = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
664+
for (size_t idx = 0; idx < features_set->Size(); idx++) {
665+
const size_t fid = features_set->ConstHostVector()[idx];
682666
if (interaction_constraints_.Query(nid, fid)) {
683-
split_queries_host_[pos].nid = nid;
684-
split_queries_host_[pos].fid = fid;
685-
split_queries_host_[pos].hist = hist[nid].DataConst();
686-
split_queries_host_[pos].best = snode_host_[nid].best;
687-
pos++;
667+
auto this_hist = hist_[nid].DataConst();
668+
if (pos < split_queries_host_.size()) {
669+
split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
670+
} else {
671+
split_queries_host_.push_back({nid, fid, this_hist});
672+
}
673+
++pos;
688674
}
689675
}
690676
}
677+
const size_t total_features = pos;
691678

692679
split_queries_device_.Resize(&qu_, total_features);
693680
auto event = qu_.memcpy(split_queries_device_.Data(), split_queries_host_.data(),
@@ -702,10 +689,14 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
702689
snode_device_.ResizeNoCopy(&qu_, snode_host_.size());
703690
event = qu_.memcpy(snode_device_.Data(), snode_host_.data(),
704691
snode_host_.size() * sizeof(NodeEntry<GradientSumT>), event);
705-
const NodeEntry<GradientSumT>* snode = snode_device_.DataConst();
692+
const NodeEntry<GradientSumT>* snode = snode_device_.Data();
706693

707694
const float min_child_weight = param_.min_child_weight;
708695

696+
best_splits_device_.ResizeNoCopy(&qu_, total_features);
697+
if (best_splits_host_.size() < total_features) best_splits_host_.resize(total_features);
698+
SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data();
699+
709700
event = qu_.submit([&](::sycl::handler& cgh) {
710701
cgh.depends_on(event);
711702
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
@@ -717,17 +708,18 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
717708
int fid = split_queries_device[i].fid;
718709
const GradientPairT* hist_data = split_queries_device[i].hist;
719710

711+
best_splits[i] = snode[nid].best;
720712
EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid],
721-
&(split_queries_device[i].best), fid, nid, evaluator, min_child_weight);
713+
&(best_splits[i]), fid, nid, evaluator, min_child_weight);
722714
});
723715
});
724-
event = qu_.memcpy(split_queries_host_.data(), split_queries_device_.Data(),
725-
total_features * sizeof(SplitQuery), event);
716+
event = qu_.memcpy(best_splits_host_.data(), best_splits,
717+
total_features * sizeof(SplitEntry<GradientSumT>), event);
726718

727719
qu_.wait();
728720
for (size_t i = 0; i < total_features; i++) {
729721
int nid = split_queries_host_[i].nid;
730-
snode_host_[nid].best.Update(split_queries_host_[i].best);
722+
snode_host_[nid].best.Update(best_splits_host_[i]);
731723
}
732724

733725
builder_monitor_.Stop("EvaluateSplits");

plugin/sycl/tree/hist_updater.h

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -95,9 +95,8 @@ class HistUpdater {
9595
friend class DistributedHistRowsAdder<GradientSumT>;
9696

9797
struct SplitQuery {
98-
int nid;
99-
int fid;
100-
SplitEntry<GradientSumT> best;
98+
bst_node_t nid;
99+
size_t fid;
101100
const GradientPairT* hist;
102101
};
103102

@@ -106,7 +105,6 @@ class HistUpdater {
106105

107106
void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
108107
const common::GHistIndexMatrix& gmat,
109-
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
110108
const RegTree& tree);
111109

112110
// Enumerate the split values of specific feature
@@ -222,6 +220,9 @@ class HistUpdater {
222220
std::vector<SplitQuery> split_queries_host_;
223221
USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
224222

223+
USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
224+
std::vector<SplitEntry<GradientSumT>> best_splits_host_;
225+
225226
TreeEvaluator<GradientSumT> tree_evaluator_;
226227
FeatureInteractionConstraintHost interaction_constraints_;
227228

tests/ci_build/conda_env/linux_sycl_test.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name: linux_sycl_test
22
channels:
33
- conda-forge
4-
- intel
4+
- https://software.repos.intel.com/python/conda/
55
dependencies:
66
- python=3.8
77
- cmake

tests/cpp/plugin/test_sycl_hist_updater.cc

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
5252
HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
5353
return HistUpdater<GradientSumT>::snode_host_[nid];
5454
}
55+
56+
auto TestEvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
57+
const common::GHistIndexMatrix& gmat,
58+
const RegTree& tree) {
59+
HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
60+
return HistUpdater<GradientSumT>::snode_host_;
61+
}
5562
};
5663

5764
void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -301,6 +308,83 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
301308
EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
302309
}
303310

311+
template <typename GradientSumT>
312+
void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
313+
const size_t num_rows = 1u << 8;
314+
const size_t num_columns = 2;
315+
const size_t n_bins = 32;
316+
317+
Context ctx;
318+
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
319+
320+
DeviceManager device_manager;
321+
auto qu = device_manager.GetQueue(ctx.Device());
322+
ObjInfo task{ObjInfo::kRegression};
323+
324+
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
325+
326+
FeatureInteractionConstraintHost int_constraints;
327+
328+
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
329+
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
330+
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
331+
332+
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
333+
auto* gpair_ptr = gpair.Data();
334+
GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
335+
336+
DeviceMatrix dmat;
337+
dmat.Init(qu, p_fmat.get());
338+
common::GHistIndexMatrix gmat;
339+
gmat.Init(qu, &ctx, dmat, n_bins);
340+
341+
RegTree tree;
342+
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
343+
ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid));
344+
345+
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
346+
auto& row_idxs = row_set_collection->Data();
347+
const size_t* row_idxs_ptr = row_idxs.DataConst();
348+
const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
349+
const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
350+
351+
const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
352+
auto best_loss_chg = snode_updated[0].best.loss_chg;
353+
auto stats = snode_init.stats;
354+
auto root_gain = snode_init.root_gain;
355+
356+
// Check all splits manually. Save the best one and compare with the ans
357+
TreeEvaluator<GradientSumT> tree_evaluator(qu, param, num_columns);
358+
auto evaluator = tree_evaluator.GetEvaluator();
359+
const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
360+
const size_t size = gmat.cut_device.Ptrs().Size();
361+
int n_better_splits = 0;
362+
const auto* hist_ptr = (*hist)[0].DataConst();
363+
std::vector<bst_float> best_loss_chg_des(1, -1);
364+
{
365+
::sycl::buffer<bst_float> best_loss_chg_buff(best_loss_chg_des.data(), 1);
366+
qu.submit([&](::sycl::handler& cgh) {
367+
auto best_loss_chg_acc = best_loss_chg_buff.template get_access<::sycl::access::mode::read_write>(cgh);
368+
cgh.single_task<>([=]() {
369+
for (size_t i = 1; i < size; ++i) {
370+
GradStats<GradientSumT> left(0, 0);
371+
GradStats<GradientSumT> right = stats - left;
372+
for (size_t j = cut_ptr[i-1]; j < cut_ptr[i]; ++j) {
373+
auto loss_change = evaluator.CalcSplitGain(0, i - 1, left, right) - root_gain;
374+
if (loss_change > best_loss_chg_acc[0]) {
375+
best_loss_chg_acc[0] = loss_change;
376+
}
377+
left.Add(hist_ptr[j].GetGrad(), hist_ptr[j].GetHess());
378+
right = stats - left;
379+
}
380+
}
381+
});
382+
}).wait();
383+
}
384+
385+
ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
386+
}
387+
304388
TEST(SyclHistUpdater, Sampling) {
305389
xgboost::tree::TrainParam param;
306390
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -340,4 +424,12 @@ TEST(SyclHistUpdater, InitNewNode) {
340424
TestHistUpdaterInitNewNode<double>(param, 0.5);
341425
}
342426

427+
TEST(SyclHistUpdater, EvaluateSplits) {
428+
xgboost::tree::TrainParam param;
429+
param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
430+
431+
TestHistUpdaterEvaluateSplits<float>(param);
432+
TestHistUpdaterEvaluateSplits<double>(param);
433+
}
434+
343435
} // namespace xgboost::sycl::tree

0 commit comments

Comments (0)