Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 7 additions & 33 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <limits>
#include <vector>

#include "../../src/tree/common_row_partitioner.h"

#include "../common/hist_util.h"
#include "../../src/collective/allreduce.h"

Expand Down Expand Up @@ -188,7 +190,7 @@ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
std::vector<ExpandEntry> nodes_for_apply_split;
AddSplitsToTree(gmat, p_tree, num_leaves, depth,
&nodes_for_apply_split, temp_qexpand_depth);
ApplySplit(nodes_for_apply_split, gmat, hist_, p_tree);
ApplySplit(nodes_for_apply_split, gmat, p_tree);
}

// Split nodes into 2 sets depending on the number of rows in each node
Expand Down Expand Up @@ -304,7 +306,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

this->ApplySplit({candidate}, gmat, hist_, p_tree);
this->ApplySplit({candidate}, gmat, p_tree);

const int cleft = (*p_tree)[nid].LeftChild();
const int cright = (*p_tree)[nid].RightChild();
Expand Down Expand Up @@ -786,34 +788,6 @@ void HistUpdater<GradientSumT>::EnumerateSplit(
best.SplitIndex() == total_split_index) p_best->Update(best);
}

// Converts the floating-point split threshold of every node in `nodes` into
// the index of the histogram cut that holds exactly that value.
//
// nodes            - nodes whose split conditions are looked up.
// tree             - supplies SplitIndex() / SplitCond() for each node id.
// gmat             - quantized matrix; gmat.cut.Ptrs() delimits each feature's
//                    range of cuts and gmat.cut.Values() holds the cut values.
// split_conditions - output, resized to nodes.size(). Element k receives the
//                    matching cut index for nodes[k], or -1 when no cut value in
//                    the feature's range equals the split threshold.
template <typename GradientSumT>
void HistUpdater<GradientSumT>::FindSplitConditions(
    const std::vector<ExpandEntry>& nodes,
    const RegTree& tree,
    const common::GHistIndexMatrix& gmat,
    std::vector<int32_t>* split_conditions) {
  const size_t node_count = nodes.size();
  split_conditions->resize(node_count);

  for (size_t node_in_set = 0; node_in_set < node_count; ++node_in_set) {
    const int32_t nid = nodes[node_in_set].nid;
    const bst_uint fid = tree[nid].SplitIndex();
    const bst_float split_pt = tree[nid].SplitCond();

    // Half-open range [first_bin, last_bin) of cuts that belong to feature `fid`.
    const uint32_t first_bin = gmat.cut.Ptrs()[fid];
    const uint32_t last_bin = gmat.cut.Ptrs()[fid + 1];

    // The result is stored as int32_t, so every candidate index has to fit.
    CHECK_LT(last_bin,
             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));

    // -1 is kept when the scan below finds no cut equal to split_pt.
    int32_t matched_bin = -1;
    uint32_t bin = first_bin;
    while (bin < last_bin) {
      if (gmat.cut.Values()[bin] == split_pt) {
        matched_bin = static_cast<int32_t>(bin);
      }
      ++bin;
    }

    (*split_conditions)[node_in_set] = matched_bin;
  }
}
template <typename GradientSumT>
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
const std::vector<ExpandEntry>& nodes,
Expand All @@ -833,13 +807,13 @@ template <typename GradientSumT>
void HistUpdater<GradientSumT>::ApplySplit(
const std::vector<ExpandEntry> nodes,
const common::GHistIndexMatrix& gmat,
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
RegTree* p_tree) {
using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;
builder_monitor_.Start("ApplySplit");

const size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
std::vector<int32_t> split_conditions(n_nodes);
CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);

partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
Expand Down
6 changes: 0 additions & 6 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,10 @@ class HistUpdater {

void ApplySplit(std::vector<ExpandEntry> nodes,
const common::GHistIndexMatrix& gmat,
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
RegTree* p_tree);

void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);


void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const common::GHistIndexMatrix& gmat,
std::vector<int32_t>* split_conditions);

void InitData(const common::GHistIndexMatrix& gmat,
const USMVector<GradientPair, MemoryType::on_device> &gpair,
const DMatrix& fmat,
Expand Down
7 changes: 4 additions & 3 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ class CommonRowPartitioner {
}
}

template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
/* Making GHistIndexMatrix_t a template parameter allows this function to be reused by the sycl plugin */
template <typename ExpandEntry, typename GHistIndexMatrix_t>
static void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix_t& gmat, std::vector<int32_t>* split_conditions) {
auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values();

Expand Down
114 changes: 114 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "../../../plugin/sycl/tree/hist_updater.h"
#include "../../../plugin/sycl/device_manager.h"

#include "../../../src/tree/common_row_partitioner.h"

#include "../helpers.h"

namespace xgboost::sycl::tree {
Expand Down Expand Up @@ -59,6 +61,12 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
return HistUpdater<GradientSumT>::snode_host_;
}

// Test-only forwarder that exposes HistUpdater<GradientSumT>::ApplySplit.
//
// nodes  - nodes to split; taken by const reference so that this wrapper does
//          not add a second copy on top of the one made by ApplySplit, which
//          receives the vector by value.
// gmat   - quantized matrix forwarded unchanged to ApplySplit.
// p_tree - tree forwarded unchanged to ApplySplit.
auto TestApplySplit(const std::vector<ExpandEntry>& nodes,
                    const common::GHistIndexMatrix& gmat,
                    RegTree* p_tree) {
  HistUpdater<GradientSumT>::ApplySplit(nodes, gmat, p_tree);
}
};

void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
Expand Down Expand Up @@ -385,6 +393,92 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
}

// Checks HistUpdater::ApplySplit against a reference sequence assembled by hand
// from CommonRowPartitioner::FindSplitConditions and common::PartitionBuilder.
//
// Both paths start from the same gmat, gradient pairs and tree; afterwards the
// row-index buffers of the two row-set collections are copied to the host and
// compared, ignoring the order of rows inside the left and right partitions.
//
// param    - training parameters handed to both updater instances.
// sparsity - forwarded to RandomDataGenerator when generating the DMatrix.
// max_bins - forwarded to GHistIndexMatrix::Init.
template <typename GradientSumT>
void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float sparsity, int max_bins) {
const size_t num_rows = 1024;
const size_t num_columns = 2;

Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());

auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
sycl::DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());

common::GHistIndexMatrix gmat;
gmat.Init(qu, &ctx, dmat, max_bins);

// Expand node 0 so that it has children to receive the partitioned rows.
// NOTE(review): all remaining arguments are zero/false — presumably a split on
// feature 0 at value 0; confirm against RegTree::ExpandNode.
RegTree tree;
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

// Node 0 is the only node whose split is applied in this test.
std::vector<tree::ExpandEntry> nodes;
nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));

FeatureInteractionConstraintHost int_constraints;
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
// Last argument is has_neg_hess = false.
GenerateRandomGPairs(&qu, gpair.Data(), num_rows, false);

// Code under test: initialise the updater, then apply the split of node 0.
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
updater.TestApplySplit(nodes, gmat, &tree);

// Copy the num_rows row indices produced by ApplySplit to the host.
std::vector<size_t> row_indices_host(num_rows);
qu.memcpy(row_indices_host.data(), row_set_collection->Data().Data(), sizeof(size_t)*num_rows).wait();

// Reference implementation: a second, freshly initialised updater supplies an
// untouched row-set collection that is partitioned step by step below.
std::vector<size_t> row_indices_desired_host(num_rows);
// Both counters are assigned inside the block below, before they are read.
size_t n_left, n_right;
{
TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, int_constraints, p_fmat.get());
auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree);

size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions(n_nodes);
xgboost::tree::CommonRowPartitioner::FindSplitConditions(nodes, tree, gmat, &split_conditions);

// The callback reports, per node in the set, how many rows that node holds.
common::PartitionBuilder partition_builder;
partition_builder.Init(&qu, n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
return (*row_set_collection4verification)[nid].Size();
});

::sycl::event event;
partition_builder.Partition(gmat, nodes, (*row_set_collection4verification),
split_conditions, &tree, &event);
// Wait for the queue before the partition results are merged.
qu.wait_and_throw();

// Write each node's merged result back over that node's own row range;
// const_cast is needed because `begin` is exposed as a pointer to const.
for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) {
const int32_t nid = nodes[node_in_set].nid;
size_t* data_result = const_cast<size_t*>((*row_set_collection4verification)[nid].begin);
partition_builder.MergeToArray(node_in_set, data_result, &event);
}
qu.wait_and_throw();

// Only one node is split, so the counts are read for node-in-set index 0.
const int32_t nid = nodes[0].nid;
n_left = partition_builder.GetNLeftElems(0);
n_right = partition_builder.GetNRightElems(0);

row_set_collection4verification->AddSplit(nid, tree[nid].LeftChild(),
tree[nid].RightChild(), n_left, n_right);

qu.memcpy(row_indices_desired_host.data(), row_set_collection4verification->Data().Data(), sizeof(size_t)*num_rows).wait();
}

// Sort [0, n_left) and [n_left, num_rows) independently in both buffers so the
// comparison does not depend on the order of rows within a partition. n_left
// is the left-partition size reported by the reference run.
std::sort(row_indices_desired_host.begin(), row_indices_desired_host.begin() + n_left);
std::sort(row_indices_host.begin(), row_indices_host.begin() + n_left);
std::sort(row_indices_desired_host.begin() + n_left, row_indices_desired_host.end());
std::sort(row_indices_host.begin() + n_left, row_indices_host.end());

for (size_t row = 0; row < num_rows; ++row) {
ASSERT_EQ(row_indices_desired_host[row], row_indices_host[row]);
}

}

TEST(SyclHistUpdater, Sampling) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
Expand Down Expand Up @@ -432,4 +526,24 @@ TEST(SyclHistUpdater, EvaluateSplits) {
TestHistUpdaterEvaluateSplits<double>(param);
}

// Runs the ApplySplit check on data generated with sparsity 0.3 and 256 bins,
// first with float and then with double gradient sums.
TEST(SyclHistUpdater, ApplySplitSparce) {
  const float sparsity = 0.3;
  const int max_bins = 256;

  xgboost::tree::TrainParam param;
  param.UpdateAllowUnknown(Args{{"max_depth", "3"}});

  TestHistUpdaterApplySplit<float>(param, sparsity, max_bins);
  TestHistUpdaterApplySplit<double>(param, sparsity, max_bins);
}

// Runs the ApplySplit check on data generated with sparsity 0.0 for three bin
// counts (256, 257 and 65537), first with float and then with double gradient
// sums.
// NOTE(review): the bin counts presumably straddle the 8-bit and 16-bit bin
// index limits — confirm against the GHistIndexMatrix bin-type selection.
TEST(SyclHistUpdater, ApplySplitDence) {
  xgboost::tree::TrainParam param;
  param.UpdateAllowUnknown(Args{{"max_depth", "3"}});

  const float sparsity = 0.0;
  const int bin_counts[] = {256, 256 + 1, (1 << 16) + 1};

  for (const int max_bins : bin_counts) {
    TestHistUpdaterApplySplit<float>(param, sparsity, max_bins);
  }
  for (const int max_bins : bin_counts) {
    TestHistUpdaterApplySplit<double>(param, sparsity, max_bins);
  }
}

} // namespace xgboost::sycl::tree
Loading