Skip to content

Commit d1d1f3c

Browse files
author
Dmitry Razdoburdin
committed
reuse FindSplitConditons from cpu branch
1 parent 916d4a4 commit d1d1f3c

File tree

4 files changed

+17
-38
lines changed

4 files changed

+17
-38
lines changed

plugin/sycl/tree/hist_updater.cc

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <limits>
1414
#include <vector>
1515

16+
#include "../../src/tree/common_row_partitioner.h"
17+
1618
#include "../common/hist_util.h"
1719
#include "../../src/collective/allreduce.h"
1820

@@ -786,34 +788,6 @@ void HistUpdater<GradientSumT>::EnumerateSplit(
786788
best.SplitIndex() == total_split_index) p_best->Update(best);
787789
}
788790

789-
template <typename GradientSumT>
790-
void HistUpdater<GradientSumT>::FindSplitConditions(
791-
const std::vector<ExpandEntry>& nodes,
792-
const RegTree& tree,
793-
const common::GHistIndexMatrix& gmat,
794-
std::vector<int32_t>* split_conditions) {
795-
const size_t n_nodes = nodes.size();
796-
split_conditions->resize(n_nodes);
797-
798-
for (size_t i = 0; i < nodes.size(); ++i) {
799-
const int32_t nid = nodes[i].nid;
800-
const bst_uint fid = tree[nid].SplitIndex();
801-
const bst_float split_pt = tree[nid].SplitCond();
802-
const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
803-
const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
804-
int32_t split_cond = -1;
805-
// convert floating-point split_pt into corresponding bin_id
806-
// split_cond = -1 indicates that split_pt is less than all known cut points
807-
CHECK_LT(upper_bound,
808-
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
809-
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
810-
if (split_pt == gmat.cut.Values()[i]) {
811-
split_cond = static_cast<int32_t>(i);
812-
}
813-
}
814-
(*split_conditions)[i] = split_cond;
815-
}
816-
}
817791
template <typename GradientSumT>
818792
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
819793
const std::vector<ExpandEntry>& nodes,
@@ -835,11 +809,12 @@ void HistUpdater<GradientSumT>::ApplySplit(
835809
const common::GHistIndexMatrix& gmat,
836810
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
837811
RegTree* p_tree) {
812+
using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;
838813
builder_monitor_.Start("ApplySplit");
839814

840815
const size_t n_nodes = nodes.size();
841-
std::vector<int32_t> split_conditions;
842-
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
816+
std::vector<int32_t> split_conditions(n_nodes);
817+
CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
843818

844819
partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) {
845820
const int32_t nid = nodes[node_in_set].nid;

plugin/sycl/tree/hist_updater.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,6 @@ class HistUpdater {
124124

125125
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);
126126

127-
128-
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
129-
const common::GHistIndexMatrix& gmat,
130-
std::vector<int32_t>* split_conditions);
131-
132127
void InitData(const common::GHistIndexMatrix& gmat,
133128
const USMVector<GradientPair, MemoryType::on_device> &gpair,
134129
const DMatrix& fmat,

src/tree/common_row_partitioner.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ class CommonRowPartitioner {
110110
}
111111
}
112112

113-
template <typename ExpandEntry>
114-
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
115-
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
113+
/* Making GHistIndexMatrix_t a templete parameter allows reuse this function for sycl-plugin */
114+
template <typename ExpandEntry, typename GHistIndexMatrix_t>
115+
void static FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
116+
const GHistIndexMatrix_t& gmat, std::vector<int32_t>* split_conditions) {
116117
auto const& ptrs = gmat.cut.Ptrs();
117118
auto const& vals = gmat.cut.Values();
118119

tests/cpp/plugin/test_sycl_hist_updater.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,4 +432,12 @@ TEST(SyclHistUpdater, EvaluateSplits) {
432432
TestHistUpdaterEvaluateSplits<double>(param);
433433
}
434434

435+
TEST(SyclHistUpdater, ApplySplit) {
436+
xgboost::tree::TrainParam param;
437+
param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
438+
439+
// TestHistUpdaterApplySplit<float>(param);
440+
// TestHistUpdaterApplySplit<double>(param);
441+
}
442+
435443
} // namespace xgboost::sycl::tree

0 commit comments

Comments
 (0)