Skip to content

Commit fabc355

Browse files
author
Dmitry Razdoburdin
committed
fix test
1 parent 8743b00 commit fabc355

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

tests/cpp/plugin/test_sycl_hist_updater.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
395395

396396
template <typename GradientSumT>
397397
void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float sparsity, int max_bins) {
398-
const size_t num_rows = 16;
399-
const size_t num_columns = 1;
398+
const size_t num_rows = 1024;
399+
const size_t num_columns = 2;
400400

401401
Context ctx;
402402
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
@@ -431,6 +431,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
431431

432432
// Reference Implementation
433433
std::vector<size_t> row_indices_desired_host(num_rows);
434+
size_t n_left, n_right;
434435
{
435436
TestHistUpdater<GradientSumT> updater4verification(&ctx, qu, param, int_constraints, p_fmat.get());
436437
auto* row_set_collection4verification = updater4verification.TestInitData(gmat, gpair, *p_fmat, tree);
@@ -457,18 +458,21 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
457458
}
458459
qu.wait_and_throw();
459460

460-
for (size_t i = 0; i < n_nodes; ++i) {
461-
const int32_t nid = nodes[i].nid;
462-
const size_t n_left = partition_builder.GetNLeftElems(i);
463-
const size_t n_right = partition_builder.GetNRightElems(i);
461+
const int32_t nid = nodes[0].nid;
462+
n_left = partition_builder.GetNLeftElems(0);
463+
n_right = partition_builder.GetNRightElems(0);
464464

465-
row_set_collection4verification->AddSplit(nid, tree[nid].LeftChild(),
466-
tree[nid].RightChild(), n_left, n_right);
467-
}
465+
row_set_collection4verification->AddSplit(nid, tree[nid].LeftChild(),
466+
tree[nid].RightChild(), n_left, n_right);
468467

469468
qu.memcpy(row_indices_desired_host.data(), row_set_collection4verification->Data().Data(), sizeof(size_t)*num_rows).wait();
470469
}
471470

471+
std::sort(row_indices_desired_host.begin(), row_indices_desired_host.begin() + n_left);
472+
std::sort(row_indices_host.begin(), row_indices_host.begin() + n_left);
473+
std::sort(row_indices_desired_host.begin() + n_left, row_indices_desired_host.end());
474+
std::sort(row_indices_host.begin() + n_left, row_indices_host.end());
475+
472476
for (size_t row = 0; row < num_rows; ++row) {
473477
ASSERT_EQ(row_indices_desired_host[row], row_indices_host[row]);
474478
}

0 commit comments

Comments
 (0)