diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc index da39cd8a723f..90d8089e40c4 100644 --- a/plugin/sycl/tree/hist_updater.cc +++ b/plugin/sycl/tree/hist_updater.cc @@ -13,6 +13,8 @@ #include #include +#include "../../src/tree/common_row_partitioner.h" + #include "../common/hist_util.h" #include "../../src/collective/allreduce.h" @@ -188,7 +190,7 @@ void HistUpdater::EvaluateAndApplySplits( std::vector nodes_for_apply_split; AddSplitsToTree(gmat, p_tree, num_leaves, depth, &nodes_for_apply_split, temp_qexpand_depth); - ApplySplit(nodes_for_apply_split, gmat, hist_, p_tree); + ApplySplit(nodes_for_apply_split, gmat, p_tree); } // Split nodes to 2 sets depending on amount of rows in each node @@ -304,7 +306,7 @@ void HistUpdater::ExpandWithLossGuide( right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); - this->ApplySplit({candidate}, gmat, hist_, p_tree); + this->ApplySplit({candidate}, gmat, p_tree); const int cleft = (*p_tree)[nid].LeftChild(); const int cright = (*p_tree)[nid].RightChild(); @@ -786,34 +788,6 @@ void HistUpdater::EnumerateSplit( best.SplitIndex() == total_split_index) p_best->Update(best); } -template -void HistUpdater::FindSplitConditions( - const std::vector& nodes, - const RegTree& tree, - const common::GHistIndexMatrix& gmat, - std::vector* split_conditions) { - const size_t n_nodes = nodes.size(); - split_conditions->resize(n_nodes); - - for (size_t i = 0; i < nodes.size(); ++i) { - const int32_t nid = nodes[i].nid; - const bst_uint fid = tree[nid].SplitIndex(); - const bst_float split_pt = tree[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; - int32_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, - static_cast(std::numeric_limits::max())); - for (uint32_t i = lower_bound; i < upper_bound; ++i) { - if (split_pt == gmat.cut.Values()[i]) { - split_cond = static_cast(i); - } - } - (*split_conditions)[i] = split_cond; - } -} template void HistUpdater::AddSplitsToRowSet( const std::vector& nodes, @@ -833,13 +807,13 @@ template void HistUpdater::ApplySplit( const std::vector nodes, const common::GHistIndexMatrix& gmat, - const common::HistCollection& hist, RegTree* p_tree) { + using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner; builder_monitor_.Start("ApplySplit"); const size_t n_nodes = nodes.size(); - std::vector split_conditions; - FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + std::vector split_conditions(n_nodes); + CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) { const int32_t nid = nodes[node_in_set].nid; diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h index 92e960e668b7..67f4b0090711 100644 --- a/plugin/sycl/tree/hist_updater.h +++ b/plugin/sycl/tree/hist_updater.h @@ -119,16 +119,10 @@ class HistUpdater { void ApplySplit(std::vector nodes, const common::GHistIndexMatrix& gmat, - const common::HistCollection& hist, RegTree* p_tree); void AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree); - - void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const common::GHistIndexMatrix& gmat, - std::vector* split_conditions); - void InitData(const common::GHistIndexMatrix& gmat, const USMVector &gpair, const DMatrix& fmat, diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index c3065ad5f135..319aba052353 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -110,9 +110,10 @@ class CommonRowPartitioner { } } - template - void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrix& gmat, std::vector* split_conditions) { + /* Making GHistIndexMatrix_t a templete parameter allows reuse this function for sycl-plugin */ + template + static void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + const GHistIndexMatrix_t& gmat, std::vector* split_conditions) { auto const& ptrs = gmat.cut.Ptrs(); auto const& vals = gmat.cut.Values(); diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc index 24270e76143f..64cabd4052cf 100644 --- a/tests/cpp/plugin/test_sycl_hist_updater.cc +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -8,6 +8,8 @@ #include "../../../plugin/sycl/tree/hist_updater.h" #include "../../../plugin/sycl/device_manager.h" +#include "../../../src/tree/common_row_partitioner.h" + #include "../helpers.h" namespace xgboost::sycl::tree { @@ -59,6 +61,12 @@ class TestHistUpdater : public HistUpdater { HistUpdater::EvaluateSplits(nodes_set, gmat, tree); return HistUpdater::snode_host_; } + + auto TestApplySplit(const std::vector nodes, + const common::GHistIndexMatrix& gmat, + RegTree* p_tree) { + HistUpdater::ApplySplit(nodes, gmat, p_tree); + } }; void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) { @@ -385,6 +393,92 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) { ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6); } +template +void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float sparsity, int max_bins) { + const size_t num_rows = 1024; + const size_t num_columns = 2; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, max_bins); + + RegTree tree; + tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + + std::vector nodes; + nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0))); + + FeatureInteractionConstraintHost int_constraints; + TestHistUpdater updater(&ctx, qu, param, int_constraints, p_fmat.get()); + USMVector gpair(&qu, num_rows); + GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false); + + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + updater.TestApplySplit(nodes, gmat, &tree); + + // Copy indexes to host + std::vector row_indices_host(num_rows); + qu.memcpy(row_indices_host.data(), row_set_collection->Data().Data(), sizeof(size_t)*num_rows).wait(); + + // Reference Implementation + std::vector row_indices_desired_host(num_rows); + size_t n_left, n_right; + { + TestHistUpdater updater4verification(&ctx, qu, param, int_constraints, p_fmat.get()); + auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree); + + size_t n_nodes = nodes.size(); + std::vector split_conditions(n_nodes); + xgboost::tree::CommonRowPartitioner::FindSplitConditions(nodes, tree, gmat, &split_conditions); + + common::PartitionBuilder partition_builder; + partition_builder.Init(&qu, n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + return (*row_set_collection4verification)[nid].Size(); + }); + + ::sycl::event event; + partition_builder.Partition(gmat, nodes, (*row_set_collection4verification), + split_conditions, &tree, &event); + qu.wait_and_throw(); + + for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + size_t* data_result = const_cast((*row_set_collection4verification)[nid].begin); + partition_builder.MergeToArray(node_in_set, data_result, &event); + } + qu.wait_and_throw(); + + const int32_t nid = nodes[0].nid; + n_left = partition_builder.GetNLeftElems(0); + n_right = partition_builder.GetNRightElems(0); + + row_set_collection4verification->AddSplit(nid, tree[nid].LeftChild(), + tree[nid].RightChild(), n_left, n_right); + + qu.memcpy(row_indices_desired_host.data(), row_set_collection4verification->Data().Data(), sizeof(size_t)*num_rows).wait(); + } + + std::sort(row_indices_desired_host.begin(), row_indices_desired_host.begin() + n_left); + std::sort(row_indices_host.begin(), row_indices_host.begin() + n_left); + std::sort(row_indices_desired_host.begin() + n_left, row_indices_desired_host.end()); + std::sort(row_indices_host.begin() + n_left, row_indices_host.end()); + + for (size_t row = 0; row < num_rows; ++row) { + ASSERT_EQ(row_indices_desired_host[row], row_indices_host[row]); + } + +} + TEST(SyclHistUpdater, Sampling) { xgboost::tree::TrainParam param; param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); @@ -432,4 +526,24 @@ TEST(SyclHistUpdater, EvaluateSplits) { TestHistUpdaterEvaluateSplits(param); } +TEST(SyclHistUpdater, ApplySplitSparce) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterApplySplit(param, 0.3, 256); + TestHistUpdaterApplySplit(param, 0.3, 256); +} + +TEST(SyclHistUpdater, ApplySplitDence) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterApplySplit(param, 0.0, 256); + TestHistUpdaterApplySplit(param, 0.0, 256+1); + TestHistUpdaterApplySplit(param, 0.0, (1u << 16) + 1); + TestHistUpdaterApplySplit(param, 0.0, 256); + TestHistUpdaterApplySplit(param, 0.0, 256+1); + TestHistUpdaterApplySplit(param, 0.0, (1u << 16) + 1); +} + } // namespace xgboost::sycl::tree