@@ -395,8 +395,8 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
395
395
396
396
template <typename GradientSumT>
397
397
void TestHistUpdaterApplySplit (const xgboost::tree::TrainParam& param, float sparsity, int max_bins) {
398
- const size_t num_rows = 16 ;
399
- const size_t num_columns = 1 ;
398
+ const size_t num_rows = 1024 ;
399
+ const size_t num_columns = 2 ;
400
400
401
401
Context ctx;
402
402
ctx.UpdateAllowUnknown (Args{{" device" , " sycl" }});
@@ -431,6 +431,7 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
431
431
432
432
// Reference Implementation
433
433
std::vector<size_t > row_indices_desired_host (num_rows);
434
+ size_t n_left, n_right;
434
435
{
435
436
TestHistUpdater<GradientSumT> updater4verification (&ctx, qu, param, int_constraints, p_fmat.get ());
436
437
auto * row_set_collection4verification = updater4verification.TestInitData (gmat, gpair, *p_fmat, tree);
@@ -457,18 +458,21 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
457
458
}
458
459
qu.wait_and_throw ();
459
460
460
- for (size_t i = 0 ; i < n_nodes; ++i) {
461
- const int32_t nid = nodes[i].nid ;
462
- const size_t n_left = partition_builder.GetNLeftElems (i);
463
- const size_t n_right = partition_builder.GetNRightElems (i);
461
+ const int32_t nid = nodes[0 ].nid ;
462
+ n_left = partition_builder.GetNLeftElems (0 );
463
+ n_right = partition_builder.GetNRightElems (0 );
464
464
465
- row_set_collection4verification->AddSplit (nid, tree[nid].LeftChild (),
466
- tree[nid].RightChild (), n_left, n_right);
467
- }
465
+ row_set_collection4verification->AddSplit (nid, tree[nid].LeftChild (),
466
+ tree[nid].RightChild (), n_left, n_right);
468
467
469
468
qu.memcpy (row_indices_desired_host.data (), row_set_collection4verification->Data ().Data (), sizeof (size_t )*num_rows).wait ();
470
469
}
471
470
471
+ std::sort (row_indices_desired_host.begin (), row_indices_desired_host.begin () + n_left);
472
+ std::sort (row_indices_host.begin (), row_indices_host.begin () + n_left);
473
+ std::sort (row_indices_desired_host.begin () + n_left, row_indices_desired_host.end ());
474
+ std::sort (row_indices_host.begin () + n_left, row_indices_host.end ());
475
+
472
476
for (size_t row = 0 ; row < num_rows; ++row) {
473
477
ASSERT_EQ (row_indices_desired_host[row], row_indices_host[row]);
474
478
}
0 commit comments