@@ -73,6 +73,14 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
                                const USMVector<GradientPair, MemoryType::on_device> &gpair) {
     HistUpdater<GradientSumT>::ExpandWithLossGuide(gmat, p_tree, gpair);
   }
+
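+  // Forwards to the base-class depth-wise expansion so the test can drive it directly.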
+  auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat,
+                               DMatrix *p_fmat,
+                               RegTree* p_tree,
+                               const USMVector<GradientPair, MemoryType::on_device> &gpair) {
+    HistUpdater<GradientSumT>::ExpandWithDepthWise(gmat, p_tree, gpair);
+  }
 };
 
 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -532,6 +540,59 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
   ASSERT_NEAR(ans[2], -0.15, 1e-6);
 }
 
+template <typename GradientSumT>
+void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 3;
+  const size_t num_columns = 1;
+  const size_t n_bins = 16;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+
+  std::vector<float> data = {7, 3, 15};
+  auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns);
+
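+  // Quantize the single-feature data into a 16-bin histogram index on the SYCL device.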
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
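+  // Hand-picked {grad, hess} pairs per row; the leaf values asserted below follow from them.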
+  std::vector<GradientPair> gpair_host = {{1, 2}, {3, 1}, {1, 1}};
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, gpair_host);
+
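+  // Set up the updater under test with batch histogram synchronization and row addition,
+  // then initialize it with the quantized data.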
+  RegTree tree;
+  FeatureInteractionConstraintHost int_constraints;
+  ObjInfo task{ObjInfo::kRegression};
+  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+
+  updater.TestExpandWithDepthWise(gmat, p_fmat.get(), &tree, gpair);
+
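+  // Route each sample down the trained tree and read off its predicted leaf value.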
+  const auto& nodes = tree.GetNodes();
+  std::vector<float> ans(data.size());
+  for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) {
+    size_t node_idx = 0;
+    while (!nodes[node_idx].IsLeaf()) {
+      node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild();
+    }
+    ans[data_idx] = nodes[node_idx].LeafValue();
+  }
+
+  ASSERT_NEAR(ans[0], -0.15, 1e-6);
+  ASSERT_NEAR(ans[1], -0.45, 1e-6);
+  ASSERT_NEAR(ans[2], -0.15, 1e-6);
+}
+
 
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -608,4 +669,13 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) {
   TestHistUpdaterExpandWithLossGuide<double>(param);
 }
 
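+// Depth-wise counterpart of the loss-guide test above, run for both gradient-sum precisions.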
+TEST(SyclHistUpdater, ExpandWithDepthWise) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "2"}});
+
+  TestHistUpdaterExpandWithDepthWise<float>(param);
+  TestHistUpdaterExpandWithDepthWise<double>(param);
+}
+
 }  // namespace xgboost::sycl::tree