Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
HistUpdater<GradientSumT>::ExpandWithLossGuide(gmat, p_tree, gpair);
}

auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat,
DMatrix *p_fmat,
RegTree* p_tree,
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
HistUpdater<GradientSumT>::ExpandWithDepthWise(gmat, p_tree, gpair);
}
};

void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
Expand Down Expand Up @@ -532,6 +539,53 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
ASSERT_NEAR(ans[2], -0.15, 1e-6);
}

template <typename GradientSumT>
void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param) {
const size_t num_rows = 3;
const size_t num_columns = 1;
const size_t n_bins = 16;

Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());

std::vector<float> data = {7, 3, 15};
auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns);

DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());
common::GHistIndexMatrix gmat;
gmat.Init(qu, &ctx, dmat, n_bins);

std::vector<GradientPair> gpair_host = {{1, 2}, {3, 1}, {1, 1}};
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, gpair_host);

RegTree tree;
FeatureInteractionConstraintHost int_constraints;
TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);

updater.TestExpandWithDepthWise(gmat, p_fmat.get(), &tree, gpair);

const auto& nodes = tree.GetNodes();
std::vector<float> ans(data.size());
for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) {
size_t node_idx = 0;
while (!nodes[node_idx].IsLeaf()) {
node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild();
}
ans[data_idx] = nodes[node_idx].LeafValue();
}

ASSERT_NEAR(ans[0], -0.15, 1e-6);
ASSERT_NEAR(ans[1], -0.45, 1e-6);
ASSERT_NEAR(ans[2], -0.15, 1e-6);
}

TEST(SyclHistUpdater, Sampling) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
Expand Down Expand Up @@ -608,4 +662,12 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) {
TestHistUpdaterExpandWithLossGuide<double>(param);
}

TEST(SyclHistUpdater, ExpandWithDepthWise) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_depth", "2"}});

TestHistUpdaterExpandWithDepthWise<float>(param);
TestHistUpdaterExpandWithDepthWise<double>(param);
}

} // namespace xgboost::sycl::tree
Loading