4 changes: 4 additions & 0 deletions .github/workflows/main.yml
@@ -90,6 +90,10 @@ jobs:
       run: |
         conda info
         conda list
+    - name: Display SYCL devices
+      shell: bash -l {0}
+      run: |
+        lscpu
     - name: Build and install XGBoost
       shell: bash -l {0}
       run: |
60 changes: 26 additions & 34 deletions plugin/sycl/tree/hist_updater.cc
@@ -183,7 +183,7 @@ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
     int *num_leaves,
     int depth,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
-  EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);
+  EvaluateSplits(qexpand_depth_wise_, gmat, *p_tree);

   std::vector<ExpandEntry> nodes_for_apply_split;
   AddSplitsToTree(gmat, p_tree, num_leaves, depth,
@@ -280,7 +280,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(

   this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree);

-  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
+  this->EvaluateSplits({node}, gmat, *p_tree);
   node.split.loss_chg = snode_host_[ExpandEntry::kRootNid].best.loss_chg;

   qexpand_loss_guided_->push(node);
@@ -325,7 +325,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
                        snode_host_[cleft].weight, snode_host_[cright].weight);
     interaction_constraints_.Split(nid, featureid, cleft, cright);

-    this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
+    this->EvaluateSplits({left_node, right_node}, gmat, *p_tree);
     left_node.split.loss_chg = snode_host_[cleft].best.loss_chg;
     right_node.split.loss_chg = snode_host_[cright].best.loss_chg;

@@ -472,7 +472,7 @@ void HistUpdater<GradientSumT>::InitSampling(
       });
     });
   } else {
-    // Use oneDPL uniform for better perf, as far as bernoulli_distribution uses fp64
+    // Use oneDPL uniform, since bernoulli_distribution uses fp64
     event = qu_.submit([&](::sycl::handler& cgh) {
       auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
       cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
@@ -649,45 +649,32 @@ template<typename GradientSumT>
 void HistUpdater<GradientSumT>::EvaluateSplits(
     const std::vector<ExpandEntry>& nodes_set,
     const common::GHistIndexMatrix& gmat,
-    const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
     const RegTree& tree) {
   builder_monitor_.Start("EvaluateSplits");

   const size_t n_nodes_in_set = nodes_set.size();

   using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
-  std::vector<FeatureSetType> features_sets(n_nodes_in_set);

-  // Generate feature set for each tree node
-  size_t total_features = 0;
-  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    features_sets[nid_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
-    for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
-      if (interaction_constraints_.Query(nid, fid)) {
-        total_features++;
-      }
-    }
-  }
-
-  split_queries_host_.resize(total_features);
   size_t pos = 0;
-
   for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const size_t nid = nodes_set[nid_in_set].nid;
-
-    for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
+    const bst_node_t nid = nodes_set[nid_in_set].nid;
+    FeatureSetType features_set = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
+    for (size_t idx = 0; idx < features_set->Size(); idx++) {
+      const size_t fid = features_set->ConstHostVector()[idx];
       if (interaction_constraints_.Query(nid, fid)) {
-        split_queries_host_[pos].nid = nid;
-        split_queries_host_[pos].fid = fid;
-        split_queries_host_[pos].hist = hist[nid].DataConst();
-        split_queries_host_[pos].best = snode_host_[nid].best;
-        pos++;
+        auto this_hist = hist_[nid].DataConst();
+        if (pos < split_queries_host_.size()) {
+          split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
+        } else {
+          split_queries_host_.push_back({nid, fid, this_hist});
+        }
+        ++pos;
       }
     }
   }
+  const size_t total_features = pos;

   split_queries_device_.Resize(&qu_, total_features);
   auto event = qu_.memcpy(split_queries_device_.Data(), split_queries_host_.data(),
@@ -702,10 +689,14 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
   snode_device_.ResizeNoCopy(&qu_, snode_host_.size());
   event = qu_.memcpy(snode_device_.Data(), snode_host_.data(),
                      snode_host_.size() * sizeof(NodeEntry<GradientSumT>), event);
-  const NodeEntry<GradientSumT>* snode = snode_device_.DataConst();
+  const NodeEntry<GradientSumT>* snode = snode_device_.Data();

   const float min_child_weight = param_.min_child_weight;

+  best_splits_device_.ResizeNoCopy(&qu_, total_features);
+  if (best_splits_host_.size() < total_features) best_splits_host_.resize(total_features);
+  SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data();
+
   event = qu_.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event);
     cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
@@ -717,17 +708,18 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
       int fid = split_queries_device[i].fid;
       const GradientPairT* hist_data = split_queries_device[i].hist;

+      best_splits[i] = snode[nid].best;
       EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid],
-                     &(split_queries_device[i].best), fid, nid, evaluator, min_child_weight);
+                     &(best_splits[i]), fid, nid, evaluator, min_child_weight);
     });
   });
-  event = qu_.memcpy(split_queries_host_.data(), split_queries_device_.Data(),
-                     total_features * sizeof(SplitQuery), event);
+  event = qu_.memcpy(best_splits_host_.data(), best_splits,
+                     total_features * sizeof(SplitEntry<GradientSumT>), event);

   qu_.wait();
   for (size_t i = 0; i < total_features; i++) {
     int nid = split_queries_host_[i].nid;
-    snode_host_[nid].best.Update(split_queries_host_[i].best);
+    snode_host_[nid].best.Update(best_splits_host_[i]);
   }

   builder_monitor_.Stop("EvaluateSplits");
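Note on the hunk above: the old two-pass scheme (count `total_features`, then `resize`) becomes a single pass that reuses `split_queries_host_` across calls. Slots are overwritten in place, `push_back` only runs past the current size, and the buffer never shrinks, so steady-state calls avoid reallocation. A minimal sketch of the pattern, assuming nothing beyond the standard library (`Query` and `CollectQueries` are illustrative names, not from the PR):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

struct Query { int nid; std::size_t fid; };

// One-pass fill that reuses the buffer's capacity across calls:
// overwrite existing slots, push_back only past the end, never shrink.
// Returns the count of valid entries; stale tail entries are ignored.
template <typename Pred>
std::size_t CollectQueries(std::vector<Query>* buf, int n_nodes,
                           std::size_t n_feats, Pred keep) {
  std::size_t pos = 0;
  for (int nid = 0; nid < n_nodes; ++nid) {
    for (std::size_t fid = 0; fid < n_feats; ++fid) {
      if (!keep(nid, fid)) continue;
      if (pos < buf->size()) {
        (*buf)[pos] = Query{nid, fid};
      } else {
        buf->push_back(Query{nid, fid});
      }
      ++pos;
    }
  }
  return pos;  // plays the role of total_features
}

int main() {
  std::vector<Query> buf;
  std::size_t n = CollectQueries(&buf, 2, 3, [](int, std::size_t f) { return f != 1; });
  std::cout << n << " queries, buffer size " << buf.size() << "\n";  // 4, 4
  std::size_t m = CollectQueries(&buf, 1, 2, [](int, std::size_t) { return true; });
  std::cout << m << " queries, buffer size " << buf.size() << "\n";  // 2, still 4
}
```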
9 changes: 5 additions & 4 deletions plugin/sycl/tree/hist_updater.h
@@ -95,9 +95,8 @@ class HistUpdater {
   friend class DistributedHistRowsAdder<GradientSumT>;

   struct SplitQuery {
-    int nid;
-    int fid;
-    SplitEntry<GradientSumT> best;
+    bst_node_t nid;
+    size_t fid;
     const GradientPairT* hist;
   };

@@ -106,7 +105,6 @@

   void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                       const common::GHistIndexMatrix& gmat,
-                      const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
                       const RegTree& tree);

   // Enumerate the split values of specific feature
@@ -222,6 +220,9 @@ class HistUpdater {
   std::vector<SplitQuery> split_queries_host_;
   USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;

+  USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
+  std::vector<SplitEntry<GradientSumT>> best_splits_host_;
+
   TreeEvaluator<GradientSumT> tree_evaluator_;
   FeatureInteractionConstraintHost interaction_constraints_;

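Note on the new members: `best_splits_device_` is device scratch resized to each batch, while `best_splits_host_` only grows; after the kernel, only the `SplitEntry` results travel back, instead of round-tripping whole `SplitQuery` records as before. A rough standalone equivalent of this pairing in plain SYCL 2020 USM, assuming a conformant compiler (`Result` and `RunBatch` are illustrative, and `USMVector` is replaced by raw `malloc_device` for brevity):

```cpp
#include <sycl/sycl.hpp>
#include <vector>

struct Result { float loss_chg; };

// Device scratch sized per batch, host mirror grown monotonically so that
// repeated calls reuse host storage; only results are copied device-to-host.
void RunBatch(sycl::queue& q, std::size_t n, std::vector<Result>* host) {
  Result* dev = sycl::malloc_device<Result>(n, q);
  if (host->size() < n) host->resize(n);  // grow, never shrink

  auto e = q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
    dev[i] = Result{static_cast<float>(i[0])};  // stand-in for EnumerateSplit
  });
  q.memcpy(host->data(), dev, n * sizeof(Result), e).wait();
  sycl::free(dev, q);
}

int main() {
  sycl::queue q;
  std::vector<Result> host;
  RunBatch(q, 8, &host);  // allocates host storage once
  RunBatch(q, 4, &host);  // reuses it on the smaller batch
}
```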
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/linux_sycl_test.yml
@@ -1,7 +1,7 @@
 name: linux_sycl_test
 channels:
 - conda-forge
-- intel
+- https://software.repos.intel.com/python/conda/
 dependencies:
 - python=3.8
 - cmake
92 changes: 92 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
@@ -52,6 +52,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
     HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
     return HistUpdater<GradientSumT>::snode_host_[nid];
   }
+
+  auto TestEvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+                          const common::GHistIndexMatrix& gmat,
+                          const RegTree& tree) {
+    HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
+    return HistUpdater<GradientSumT>::snode_host_;
+  }
 };

 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -301,6 +308,83 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
 }

+template <typename GradientSumT>
+void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 1u << 8;
+  const size_t num_columns = 2;
+  const size_t n_bins = 32;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+  ObjInfo task{ObjInfo::kRegression};
+
+  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
+
+  FeatureInteractionConstraintHost int_constraints;
+
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
+  auto* gpair_ptr = gpair.Data();
+  GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
+
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
+  RegTree tree;
+  tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
+  ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid));
+
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+  auto& row_idxs = row_set_collection->Data();
+  const size_t* row_idxs_ptr = row_idxs.DataConst();
+  const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
+  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+
+  const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
+  auto best_loss_chg = snode_updated[0].best.loss_chg;
+  auto stats = snode_init.stats;
+  auto root_gain = snode_init.root_gain;
+
+  // Check all splits manually. Save the best one and compare with the answer.
+  TreeEvaluator<GradientSumT> tree_evaluator(qu, param, num_columns);
+  auto evaluator = tree_evaluator.GetEvaluator();
+  const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
+  const size_t size = gmat.cut_device.Ptrs().Size();
+  int n_better_splits = 0;
+  const auto* hist_ptr = (*hist)[0].DataConst();
+  std::vector<bst_float> best_loss_chg_des(1, -1);
+  {
+    ::sycl::buffer<bst_float> best_loss_chg_buff(best_loss_chg_des.data(), 1);
+    qu.submit([&](::sycl::handler& cgh) {
+      auto best_loss_chg_acc = best_loss_chg_buff.template get_access<::sycl::access::mode::read_write>(cgh);
+      cgh.single_task<>([=]() {
+        for (size_t i = 1; i < size; ++i) {
+          GradStats<GradientSumT> left(0, 0);
+          GradStats<GradientSumT> right = stats - left;
+          for (size_t j = cut_ptr[i-1]; j < cut_ptr[i]; ++j) {
+            auto loss_change = evaluator.CalcSplitGain(0, i - 1, left, right) - root_gain;
+            if (loss_change > best_loss_chg_acc[0]) {
+              best_loss_chg_acc[0] = loss_change;
+            }
+            left.Add(hist_ptr[j].GetGrad(), hist_ptr[j].GetHess());
+            right = stats - left;
+          }
+        }
+      });
+    }).wait();
+  }
+
+  ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
+}
+
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -340,4 +424,12 @@ TEST(SyclHistUpdater, InitNewNode) {
   TestHistUpdaterInitNewNode<double>(param, 0.5);
 }

+TEST(SyclHistUpdater, EvaluateSplits) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
+
+  TestHistUpdaterEvaluateSplits<float>(param);
+  TestHistUpdaterEvaluateSplits<double>(param);
+}
+
 }  // namespace xgboost::sycl::tree
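For readers following the new test: the device-side reference check reduces to a prefix-sum sweep over one feature's histogram bins, evaluating the candidate split at every bin boundary. A host-side sketch of the same logic (`Stats`, `BestLossChange`, and the toy `gain` stand in for `GradStats`, the test loop, and `evaluator.CalcSplitGain() - root_gain`; none are from the PR):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

struct Stats { double grad = 0.0, hess = 0.0; };

// Sweep split positions left to right: `left` is the prefix sum of bins
// before the candidate boundary, `right` is the complement, and the best
// loss change over all boundaries is returned.
template <typename Gain>
double BestLossChange(const std::vector<Stats>& bins, const Stats& total, Gain gain) {
  double best = -1.0;
  Stats left;
  for (std::size_t j = 0; j < bins.size(); ++j) {
    Stats right{total.grad - left.grad, total.hess - left.hess};
    best = std::max(best, gain(left, right));  // boundary before bin j
    left.grad += bins[j].grad;
    left.hess += bins[j].hess;
  }
  return best;
}

int main() {
  std::vector<Stats> bins{{1.0, 1.0}, {-2.0, 1.0}, {1.0, 1.0}};
  Stats total{0.0, 3.0};
  auto gain = [](const Stats& l, const Stats& r) {
    return -(l.grad * l.grad + r.grad * r.grad);  // toy objective
  };
  std::cout << BestLossChange(bins, total, gain) << "\n";  // 0, at the first boundary
}
```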