Skip to content

Commit 0def8e0

Browse files
authored
[sycl] fix fitting for fp32 devices (dmlc#10702)
Co-authored-by: Dmitry Razdoburdin <>
1 parent 773ded6 commit 0def8e0

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

plugin/sycl/tree/hist_updater.cc

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -322,23 +322,49 @@ void HistUpdater<GradientSumT>::InitSampling(
322322
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
323323
uint64_t seed = seed_;
324324
seed_ += num_rows;
325-
event = qu_.submit([&](::sycl::handler& cgh) {
326-
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
327-
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
328-
[=](::sycl::item<1> pid) {
329-
uint64_t i = pid.get_id(0);
330-
331-
// Create minstd_rand engine
332-
oneapi::dpl::minstd_rand engine(seed, i);
333-
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
334-
335-
auto rnd = coin_flip(engine);
336-
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
337-
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
338-
row_idx[num_samples_ref++] = i;
339-
}
325+
326+
/*
327+
* oneDPL bernoulli_distribution implicitly uses double.
328+
* If the device doesn't have fp64 support,
329+
we generate Bernoulli-distributed random values from a uniform distribution
330+
*/
331+
if (has_fp64_support_) {
332+
// Use oneDPL bernoulli_distribution for better perf
333+
event = qu_.submit([&](::sycl::handler& cgh) {
334+
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
335+
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
336+
[=](::sycl::item<1> pid) {
337+
uint64_t i = pid.get_id(0);
338+
// Create minstd_rand engine
339+
oneapi::dpl::minstd_rand engine(seed, i);
340+
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
341+
auto bernoulli_rnd = coin_flip(engine);
342+
343+
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
344+
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
345+
row_idx[num_samples_ref++] = i;
346+
}
347+
});
340348
});
341-
});
349+
} else {
350+
// Use oneDPL uniform_real_distribution<float>, because bernoulli_distribution uses fp64
351+
event = qu_.submit([&](::sycl::handler& cgh) {
352+
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
353+
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
354+
[=](::sycl::item<1> pid) {
355+
uint64_t i = pid.get_id(0);
356+
oneapi::dpl::minstd_rand engine(seed, i);
357+
oneapi::dpl::uniform_real_distribution<float> distr;
358+
const float rnd = distr(engine);
359+
const bool bernoulli_rnd = rnd < subsample ? 1 : 0;
360+
361+
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
362+
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
363+
row_idx[num_samples_ref++] = i;
364+
}
365+
});
366+
});
367+
}
342368
/* After calling a destructor for flag_buf, content will be copied to num_samples */
343369
}
344370

plugin/sycl/tree/hist_updater.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class HistUpdater {
6767
if (param.max_depth > 0) {
6868
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
6969
}
70+
has_fp64_support_ = qu_.get_device().has(::sycl::aspect::fp64);
7071
const auto sub_group_sizes =
7172
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
7273
sub_group_size_ = sub_group_sizes.back();
@@ -183,6 +184,7 @@ class HistUpdater {
183184

184185
// --data fields--
185186
const Context* ctx_;
187+
bool has_fp64_support_;
186188
size_t sub_group_size_;
187189

188190
// the internal row sets

0 commit comments

Comments
 (0)