@@ -322,23 +322,49 @@ void HistUpdater<GradientSumT>::InitSampling(
322
322
::sycl::buffer<uint64_t , 1 > flag_buf (&num_samples, 1 );
323
323
uint64_t seed = seed_;
324
324
seed_ += num_rows;
325
- event = qu_.submit ([&](::sycl::handler& cgh) {
326
- auto flag_buf_acc = flag_buf.get_access <::sycl::access::mode::read_write>(cgh);
327
- cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
328
- [=](::sycl::item<1 > pid) {
329
- uint64_t i = pid.get_id (0 );
330
-
331
- // Create minstd_rand engine
332
- oneapi::dpl::minstd_rand engine (seed, i);
333
- oneapi::dpl::bernoulli_distribution coin_flip (subsample);
334
-
335
- auto rnd = coin_flip (engine);
336
- if (gpair_ptr[i].GetHess () >= 0 .0f && rnd) {
337
- AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
338
- row_idx[num_samples_ref++] = i;
339
- }
325
+
326
+ /*
327
+ * oneDPL bernoulli_distribution implicitly uses double.
328
+ * If the device doesn't have fp64 support,
329
+ * we generate bernoulli distributed random values from uniform distribution
330
+ */
331
+ if (has_fp64_support_) {
332
+ // Use oneDPL bernoulli_distribution for better perf
333
+ event = qu_.submit ([&](::sycl::handler& cgh) {
334
+ auto flag_buf_acc = flag_buf.get_access <::sycl::access::mode::read_write>(cgh);
335
+ cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
336
+ [=](::sycl::item<1 > pid) {
337
+ uint64_t i = pid.get_id (0 );
338
+ // Create minstd_rand engine
339
+ oneapi::dpl::minstd_rand engine (seed, i);
340
+ oneapi::dpl::bernoulli_distribution coin_flip (subsample);
341
+ auto bernoulli_rnd = coin_flip (engine);
342
+
343
+ if (gpair_ptr[i].GetHess () >= 0 .0f && bernoulli_rnd) {
344
+ AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
345
+ row_idx[num_samples_ref++] = i;
346
+ }
347
+ });
340
348
});
341
- });
349
+ } else {
350
+ // Use oneDPL uniform_real_distribution<float>, because bernoulli_distribution uses fp64
351
+ event = qu_.submit ([&](::sycl::handler& cgh) {
352
+ auto flag_buf_acc = flag_buf.get_access <::sycl::access::mode::read_write>(cgh);
353
+ cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
354
+ [=](::sycl::item<1 > pid) {
355
+ uint64_t i = pid.get_id (0 );
356
+ oneapi::dpl::minstd_rand engine (seed, i);
357
+ oneapi::dpl::uniform_real_distribution<float > distr;
358
+ const float rnd = distr (engine);
359
+ const bool bernoulli_rnd = rnd < subsample ? 1 : 0 ;
360
+
361
+ if (gpair_ptr[i].GetHess () >= 0 .0f && bernoulli_rnd) {
362
+ AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
363
+ row_idx[num_samples_ref++] = i;
364
+ }
365
+ });
366
+ });
367
+ }
342
368
/* After the destructor of flag_buf is called, its content will be copied to num_samples */
343
369
}
344
370
0 commit comments