
Commit 0058301

[sycl] optimise hist building (dmlc#10311)
Co-authored-by: Dmitry Razdoburdin <>
1 parent 9def441 commit 0058301

3 files changed: +54 -38 lines changed


plugin/sycl/common/hist_util.cc

Lines changed: 53 additions & 33 deletions
@@ -64,6 +64,30 @@ template ::sycl::event SubtractionHist(::sycl::queue qu,
                                        const GHistRow<double, MemoryType::on_device>& src2,
                                        size_t size, ::sycl::event event_priv);
 
+inline auto GetBlocksParameters(const ::sycl::queue& qu, size_t size, size_t max_nblocks) {
+  struct _ {
+    size_t block_size, nblocks;
+  };
+
+  const size_t min_block_size = 32;
+  const size_t max_compute_units =
+      qu.get_device().get_info<::sycl::info::device::max_compute_units>();
+
+  size_t nblocks = max_compute_units;
+
+  size_t block_size = size / nblocks + !!(size % nblocks);
+  if (block_size > (1u << 12)) {
+    nblocks = max_nblocks;
+    block_size = size / nblocks + !!(size % nblocks);
+  }
+  if (block_size < min_block_size) {
+    block_size = min_block_size;
+    nblocks = size / block_size + !!(size % block_size);
+  }
+
+  return _{block_size, nblocks};
+}
+
 // Kernel with buffer using
 template<typename FPType, typename BinIdxType, bool isDense>
 ::sycl::event BuildHistKernel(::sycl::queue qu,
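
Note on the new helper: GetBlocksParameters derives the row partitioning from the device's compute-unit count rather than the previous fixed minimum block of 128 rows, clamping the result between 32 rows per block and the buffer capacity. A minimal host-side sketch of the same heuristic in plain C++ (the name GetBlocksParametersSketch and the example sizes are hypothetical; a constant stands in for the ::sycl::info::device::max_compute_units query):

#include <cstddef>
#include <cstdio>

struct BlockParams { std::size_t block_size, nblocks; };

// Same arithmetic as GetBlocksParameters above, minus the SYCL device query.
BlockParams GetBlocksParametersSketch(std::size_t size, std::size_t max_nblocks,
                                      std::size_t max_compute_units) {
  const std::size_t min_block_size = 32;
  std::size_t nblocks = max_compute_units;                 // start with one block per compute unit
  std::size_t block_size = size / nblocks + !!(size % nblocks);
  if (block_size > (1u << 12)) {                           // blocks over 4096 rows: spread over the whole buffer
    nblocks = max_nblocks;
    block_size = size / nblocks + !!(size % nblocks);
  }
  if (block_size < min_block_size) {                       // blocks under 32 rows: clamp and recompute the count
    block_size = min_block_size;
    nblocks = size / block_size + !!(size % block_size);
  }
  return {block_size, nblocks};
}

int main() {
  // Hypothetical case: 1,000,000 rows, buffer room for 2048 blocks, 96 compute units.
  const BlockParams p = GetBlocksParametersSketch(1000000, 2048, 96);
  std::printf("block_size=%zu nblocks=%zu\n", p.block_size, p.nblocks);  // prints block_size=489 nblocks=2048
  return 0;
}
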
@@ -73,27 +97,26 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
                               GHistRow<FPType, MemoryType::on_device>* hist,
                               GHistRow<FPType, MemoryType::on_device>* hist_buffer,
                               ::sycl::event event_priv) {
+  using GradientPairT = xgboost::detail::GradientPairInternal<FPType>;
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
-  const GradientPair::ValueT* pgh =
-      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const auto* pgh = gpair_device.DataConst();
   const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
   const uint32_t* offsets = gmat.index.Offset();
-  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
   const size_t nbins = gmat.nbins;
 
   const size_t max_work_group_size =
       qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
   const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;
 
-  const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
-  const size_t min_block_size = 128;
-  size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
-  const size_t block_size = size / nblocks + !!(size % nblocks);
-  FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer->Data());
+  // Captured structured bindings are a C++20 extension
+  const auto block_params = GetBlocksParameters(qu, size, hist_buffer->Size() / (nbins * 2));
+  const size_t block_size = block_params.block_size;
+  const size_t nblocks = block_params.nblocks;
 
-  auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
+  GradientPairT* hist_buffer_data = hist_buffer->Data();
+  auto event_fill = qu.fill(hist_buffer_data, GradientPairT(0, 0), nblocks * nbins * 2, event_priv);
   auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
     cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
@@ -102,13 +125,14 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
       size_t block = pid.get_global_id(0);
       size_t feat = pid.get_global_id(1);
 
-      FPType* hist_local = hist_buffer_data + block * nbins * 2;
+      GradientPairT* hist_local = hist_buffer_data + block * nbins;
       for (size_t idx = 0; idx < block_size; ++idx) {
         size_t i = block * block_size + idx;
         if (i < size) {
           const size_t icol_start = n_columns * rid[i];
           const size_t idx_gh = rid[i];
 
+          const GradientPairT pgh_row = {pgh[idx_gh].GetGrad(), pgh[idx_gh].GetHess()};
           pid.barrier(::sycl::access::fence_space::local_space);
           const BinIdxType* gr_index_local = gradient_index + icol_start;
 
@@ -118,30 +142,27 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
               idx_bin += offsets[j];
             }
             if (idx_bin < nbins) {
-              hist_local[2 * idx_bin] += pgh[2 * idx_gh];
-              hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1];
+              hist_local[idx_bin] += pgh_row;
             }
           }
         }
       }
     });
   });
 
+  GradientPairT* hist_data = hist->Data();
   auto event_save = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_main);
     cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
       size_t idx_bin = pid.get_id(0);
 
-      FPType gsum = 0.0f;
-      FPType hsum = 0.0f;
+      GradientPairT gpair = {0, 0};
 
       for (size_t j = 0; j < nblocks; ++j) {
-        gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
-        hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
+        gpair += hist_buffer_data[j * nbins + idx_bin];
       }
 
-      hist_data[2 * idx_bin] = gsum;
-      hist_data[2 * idx_bin + 1] = hsum;
+      hist_data[idx_bin] = gpair;
     });
   });
   return event_save;
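
Note on the buffer layout: the buffered kernel now accumulates whole GradientPairInternal values, so each block owns a contiguous row of nbins pairs (stride nbins) instead of nbins * 2 interleaved grad/hess scalars, and the final reduction adds one pair per bin. A minimal CPU-side sketch of that reduction in plain C++ (GradPair and ReduceHist are hypothetical stand-ins; no SYCL involved):

#include <cstddef>
#include <vector>

// Hypothetical stand-in for xgboost::detail::GradientPairInternal<float>.
struct GradPair {
  float grad = 0.f, hess = 0.f;
  GradPair& operator+=(const GradPair& o) { grad += o.grad; hess += o.hess; return *this; }
};

// Mirror of the event_save step above: sum the per-block partial histograms
// (nblocks x nbins pairs, row-major) into the final histogram of nbins pairs.
void ReduceHist(const std::vector<GradPair>& hist_buffer,
                std::vector<GradPair>* hist,
                std::size_t nblocks, std::size_t nbins) {
  for (std::size_t idx_bin = 0; idx_bin < nbins; ++idx_bin) {
    GradPair gpair;                                   // {0, 0}
    for (std::size_t j = 0; j < nblocks; ++j) {
      gpair += hist_buffer[j * nbins + idx_bin];      // one pair per bin, not two scalars
    }
    (*hist)[idx_bin] = gpair;
  }
}

int main() {
  const std::size_t nblocks = 2, nbins = 3;
  std::vector<GradPair> buffer = {{1.f, 1.f}, {2.f, 2.f}, {3.f, 3.f},   // partial histogram of block 0
                                  {4.f, 4.f}, {5.f, 5.f}, {6.f, 6.f}};  // partial histogram of block 1
  std::vector<GradPair> hist(nbins);
  ReduceHist(buffer, &hist, nblocks, nbins);          // hist = {5,5}, {7,7}, {9,9}
  return 0;
}
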
@@ -165,33 +186,36 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
   FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
   const size_t nbins = gmat.nbins;
 
-  const size_t max_work_group_size =
-      qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
-  const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+  constexpr size_t work_group_size = 32;
+  const size_t n_work_groups = n_columns / work_group_size + (n_columns % work_group_size > 0);
 
   auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
   auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
-    cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
-                       [=](::sycl::item<2> pid) {
-      size_t i = pid.get_id(0);
-      size_t feat = pid.get_id(1);
+    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(size, n_work_groups * work_group_size),
+                                           ::sycl::range<2>(1, work_group_size)),
+                       [=](::sycl::nd_item<2> pid) {
+      const int i = pid.get_global_id(0);
+      auto group = pid.get_group();
 
       const size_t icol_start = n_columns * rid[i];
       const size_t idx_gh = rid[i];
-
+      const FPType pgh_row[2] = {pgh[2 * idx_gh], pgh[2 * idx_gh + 1]};
       const BinIdxType* gr_index_local = gradient_index + icol_start;
 
-      for (size_t j = feat; j < n_columns; j += feat_local) {
+      const size_t group_id = group.get_group_id()[1];
+      const size_t local_id = group.get_local_id()[1];
+      const size_t j = group_id * work_group_size + local_id;
+      if (j < n_columns) {
        uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
        if constexpr (isDense) {
          idx_bin += offsets[j];
        }
        if (idx_bin < nbins) {
          AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
          AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
-          gsum.fetch_add(pgh[2 * idx_gh]);
-          hsum.fetch_add(pgh[2 * idx_gh + 1]);
+          gsum += pgh_row[0];
+          hsum += pgh_row[1];
        }
      }
    });
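
Note on the atomic kernel: it now launches an ::sycl::nd_range with a fixed work-group size of 32 along the feature dimension, so every work-item handles exactly one (row, feature) pair guarded by j < n_columns, instead of striding a loop over features. A host-side sketch of just that index math in plain C++ (the 13-column shape is hypothetical; no SYCL involved):

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical shape: 13 feature columns, fixed work-group size of 32.
  const std::size_t n_columns = 13;
  const std::size_t work_group_size = 32;
  const std::size_t n_work_groups =
      n_columns / work_group_size + (n_columns % work_group_size > 0);   // rounds up to 1 group

  // Feature index covered by each (group_id, local_id), as in the kernel above.
  for (std::size_t group_id = 0; group_id < n_work_groups; ++group_id) {
    for (std::size_t local_id = 0; local_id < work_group_size; ++local_id) {
      const std::size_t j = group_id * work_group_size + local_id;
      if (j < n_columns) {   // tail guard: the last 19 work-items of the group do nothing
        std::printf("work-item (%zu, %zu) -> feature %zu\n", group_id, local_id, j);
      }
    }
  }
  return 0;
}
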
@@ -214,10 +238,6 @@ ::sycl::event BuildHistDispatchKernel(
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
   const size_t nbins = gmat.nbins;
 
-  // max cycle size, while atomics are still effective
-  const size_t max_cycle_size_atomics = nbins;
-  const size_t cycle_size = size;
-
   // TODO(razdoburdin): replace the add-hock dispatching criteria by more sutable one
   bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);
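
Note on dispatching: with the cycle-size bookkeeping removed, the choice between the atomic and the buffered kernel reduces to the predicate shown above, i.e. use atomics when the row block is smaller than the histogram or when every feature carries the maximum bin count. A standalone sketch of that predicate (UseAtomicKernel is a hypothetical helper, not part of the patch):

#include <cstddef>

// Prefer the atomic kernel for small row blocks, or when all features
// have max_num_bins bins (nbins / n_columns), mirroring the criterion above.
bool UseAtomicKernel(std::size_t size, std::size_t nbins,
                     std::size_t n_columns, std::size_t max_num_bins) {
  return (size < nbins) || (max_num_bins == nbins / n_columns);
}
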
223243

plugin/sycl/tree/hist_updater.cc

Lines changed: 1 addition & 4 deletions
@@ -136,10 +136,7 @@ void HistUpdater<GradientSumT>::InitData(
 
   hist_buffer_.Init(qu_, nbins);
   size_t buffer_size = kBufferSize;
-  if (buffer_size > info.num_row_ / kMinBlockSize + 1) {
-    buffer_size = info.num_row_ / kMinBlockSize + 1;
-  }
-  hist_buffer_.Reset(buffer_size);
+  hist_buffer_.Reset(kBufferSize);
 
   // initialize histogram builder
   hist_builder_ = common::GHistBuilder<GradientSumT>(qu_, nbins);

plugin/sycl/tree/hist_updater.h

Lines changed: 0 additions & 1 deletion
@@ -130,7 +130,6 @@ class HistUpdater {
   DataLayout data_layout_;
 
   constexpr static size_t kBufferSize = 2048;
-  constexpr static size_t kMinBlockSize = 128;
   common::GHistBuilder<GradientSumT> hist_builder_;
   common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
   /*! \brief culmulative histogram of gradients. */
