@@ -49,9 +49,8 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
   auto TestInitNewNode(int nid,
                        const common::GHistIndexMatrix& gmat,
                        const USMVector<GradientPair, MemoryType::on_device> &gpair,
-                       const DMatrix& fmat,
                        const RegTree& tree) {
-    HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
+    HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, tree);
     return HistUpdater<GradientSumT>::snode_host_[nid];
   }

@@ -67,6 +66,16 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
                       RegTree* p_tree) {
     HistUpdater<GradientSumT>::ApplySplit(nodes, gmat, p_tree);
   }
+
+  auto TestExpandWithLossGuide(const common::GHistIndexMatrix& gmat,
+                               DMatrix *p_fmat,
+                               RegTree* p_tree,
+                               const USMVector<GradientPair, MemoryType::on_device> &gpair) {
+    // HistUpdater<GradientSumT>::tree_evaluator_.Reset(HistUpdater<GradientSumT>::qu_,
+    //                                                  HistUpdater<GradientSumT>::param_,
+    //                                                  p_fmat->Info().num_col_);
+    HistUpdater<GradientSumT>::ExpandWithLossGuide(gmat, p_tree, gpair);
+  }
 };

 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -295,7 +304,7 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   auto& row_idxs = row_set_collection->Data();
   const size_t* row_idxs_ptr = row_idxs.DataConst();
   updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
-  const auto snode = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+  const auto snode = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, tree);

   GradStats<GradientSumT> grad_stat;
   {
@@ -354,7 +363,7 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
   auto& row_idxs = row_set_collection->Data();
   const size_t* row_idxs_ptr = row_idxs.DataConst();
   const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
-  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, tree);

   const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
   auto best_loss_chg = snode_updated[0].best.loss_chg;
@@ -479,6 +488,53 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa

 }

+template <typename GradientSumT>
+void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 3;
+  const size_t num_columns = 1;
+  const size_t n_bins = 16;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+
+  std::vector<float> data = {7, 3, 15};
+  auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns);
+
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
+  std::vector<GradientPair> gpair_host = {{1, 2}, {3, 1}, {1, 1}};
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, gpair_host);
+
+  RegTree tree;
+  FeatureInteractionConstraintHost int_constraints;
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+
+  updater.TestExpandWithLossGuide(gmat, p_fmat.get(), &tree, gpair);
+
+  const auto& nodes = tree.GetNodes();
+  std::vector<float> ans(data.size());
+  for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) {
+    size_t node_idx = 0;
+    while (!nodes[node_idx].IsLeaf()) {
+      node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild();
+    }
+    ans[data_idx] = nodes[node_idx].LeafValue();
+  }
+
+  ASSERT_NEAR(ans[0], -0.15, 1e-6);
+  ASSERT_NEAR(ans[1], -0.45, 1e-6);
+  ASSERT_NEAR(ans[2], -0.15, 1e-6);
+}
+
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -546,4 +602,13 @@ TEST(SyclHistUpdater, ApplySplitDence) {
   TestHistUpdaterApplySplit<double>(param, 0.0, (1u << 16) + 1);
 }

+TEST(SyclHistUpdater, ExpandWithLossGuide) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "2"},
+                                {"grow_policy", "lossguide"}});
+
+  TestHistUpdaterExpandWithLossGuide<float>(param);
+  TestHistUpdaterExpandWithLossGuide<double>(param);
+}
+
 } // namespace xgboost::sycl::tree
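Note on the expected values in the new test: the three ASSERT_NEAR targets can be reproduced by hand from the standard XGBoost leaf-weight formula w = -eta * G / (H + lambda). The sketch below is not part of the patch; it assumes the TrainParam defaults learning_rate = 0.3 and reg_lambda = 1, and it infers from ans[0] == ans[2] that the lossguide expansion places row 1 (x = 3) in one leaf and rows 0 and 2 (x = 7 and 15) in the other.

// Standalone sanity check of the ASSERT_NEAR targets, assuming the
// XGBoost defaults learning_rate (eta) = 0.3 and reg_lambda = 1.
#include <cassert>
#include <cmath>

// Leaf value: w = -eta * sum(grad) / (sum(hess) + lambda).
float LeafValue(float grad_sum, float hess_sum) {
  const float eta = 0.3f, lambda = 1.0f;
  return -eta * grad_sum / (hess_sum + lambda);
}

int main() {
  // Row 1 (x = 3, gpair {3, 1}) ends up alone in a leaf:
  assert(std::fabs(LeafValue(3.0f, 1.0f) - (-0.45f)) < 1e-6f);  // -0.3 * 3 / 2
  // Rows 0 and 2 (x = 7 and 15, gpairs {1, 2} and {1, 1}) share a leaf:
  assert(std::fabs(LeafValue(2.0f, 3.0f) - (-0.15f)) < 1e-6f);  // -0.3 * 2 / 4
  return 0;
}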