Skip to content

Commit 34d4ab4

Browse files
authored
[EM] Avoid stream sync in quantile sketching. (dmlc#10765)
.
1 parent 61dd854 commit 34d4ab4

12 files changed

+313
-313
lines changed

src/common/algorithm.cuh

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2022-2023 by XGBoost Contributors
2+
* Copyright 2022-2024, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
55
#define XGBOOST_COMMON_ALGORITHM_CUH_
@@ -258,5 +258,19 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
258258
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
259259
cuctx->Stream()));
260260
}
261+
262+
template <typename InIt, typename OutIt, typename Predicate>
263+
void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_first,
264+
Predicate pred) {
265+
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
266+
// See thrust issue #1302, XGBoost #6822
267+
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
268+
size_t length = std::distance(in_first, in_second);
269+
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
270+
auto begin_input = in_first + offset;
271+
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
272+
out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
273+
}
274+
}
261275
} // namespace xgboost::common
262276
#endif // XGBOOST_COMMON_ALGORITHM_CUH_

src/common/device_helpers.cuh

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -637,12 +637,11 @@ struct SegmentedUniqueReduceOp {
637637
* \return Number of unique values in total.
638638
*/
639639
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
640-
typename ValOutIt, typename CompValue, typename CompKey>
641-
size_t
642-
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
643-
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
644-
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
645-
CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
640+
typename ValOutIt, typename CompValue, typename CompKey = thrust::equal_to<size_t>>
641+
size_t SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
642+
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
643+
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
644+
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
646645
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
647646
auto unique_key_it = dh::MakeTransformIterator<Key>(
648647
thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -676,16 +675,6 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
676675
return n_uniques;
677676
}
678677

679-
template <typename... Inputs,
680-
std::enable_if_t<std::tuple_size<std::tuple<Inputs...>>::value == 7>
681-
* = nullptr>
682-
size_t SegmentedUnique(Inputs &&...inputs) {
683-
dh::XGBCachingDeviceAllocator<char> alloc;
684-
return SegmentedUnique(thrust::cuda::par(alloc),
685-
std::forward<Inputs &&>(inputs)...,
686-
thrust::equal_to<size_t>{});
687-
}
688-
689678
/**
690679
* \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`.
691680
*
@@ -793,21 +782,6 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
793782
#endif
794783
}
795784

796-
template <typename InIt, typename OutIt, typename Predicate>
797-
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
798-
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
799-
// See thrust issue #1302, XGBoost #6822
800-
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
801-
size_t length = std::distance(in_first, in_second);
802-
XGBCachingDeviceAllocator<char> alloc;
803-
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
804-
auto begin_input = in_first + offset;
805-
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
806-
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
807-
end_input, out_first, pred);
808-
}
809-
}
810-
811785
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
812786
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
813787
InclusiveScan(d_in, d_out, cub::Sum(), num_items);

src/common/hist_util.cu

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
106106
return std::min(sketch_batch_num_elements, kIntMax);
107107
}
108108

109-
void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
109+
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
110+
dh::device_vector<Entry>* sorted_entries) {
110111
// Sort both entries and weights.
111-
dh::XGBDeviceAllocator<char> alloc;
112+
auto cuctx = ctx->CUDACtx();
112113
CHECK_EQ(weights->size(), sorted_entries->size());
113-
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
114-
weights->begin(), detail::EntryCompareOp());
114+
thrust::sort_by_key(cuctx->TP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
115+
detail::EntryCompareOp());
115116

116117
// Scan weights
117-
dh::XGBCachingDeviceAllocator<char> caching;
118118
thrust::inclusive_scan_by_key(
119-
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
119+
cuctx->CTP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
120120
weights->begin(),
121121
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
122122
}
123123

124-
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
124+
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
125+
Span<bst_idx_t> d_cuts_ptr,
125126
dh::device_vector<Entry>* p_sorted_entries,
126127
dh::device_vector<float>* p_sorted_weights,
127128
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
128-
info.feature_types.SetDevice(device);
129+
info.feature_types.SetDevice(ctx->Device());
129130
auto d_feature_types = info.feature_types.ConstDeviceSpan();
130131
CHECK(!d_feature_types.empty());
131132
auto& column_sizes_scan = *p_column_sizes_scan;
@@ -142,30 +143,32 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
142143
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
143144
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
144145
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
145-
n_uniques = dh::SegmentedUnique(
146-
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
147-
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
148-
[=] __device__(Pair const& l, Pair const& r) {
149-
Entry const& le = thrust::get<0>(l);
150-
Entry const& re = thrust::get<0>(r);
151-
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
152-
return le.fvalue == re.fvalue;
153-
}
154-
return false;
155-
});
146+
n_uniques =
147+
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
148+
column_sizes_scan.data().get() + column_sizes_scan.size(), val_in_it,
149+
val_in_it + sorted_entries.size(), new_column_scan.data().get(),
150+
val_out_it, [=] __device__(Pair const& l, Pair const& r) {
151+
Entry const& le = thrust::get<0>(l);
152+
Entry const& re = thrust::get<0>(r);
153+
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
154+
return le.fvalue == re.fvalue;
155+
}
156+
return false;
157+
});
156158
p_sorted_weights->resize(n_uniques);
157159
} else {
158-
n_uniques = dh::SegmentedUnique(
159-
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
160-
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
161-
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
162-
if (l.index == r.index) {
163-
if (IsCat(d_feature_types, l.index)) {
164-
return l.fvalue == r.fvalue;
165-
}
166-
}
167-
return false;
168-
});
160+
n_uniques = dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
161+
column_sizes_scan.data().get() + column_sizes_scan.size(),
162+
sorted_entries.begin(), sorted_entries.end(),
163+
new_column_scan.data().get(), sorted_entries.begin(),
164+
[=] __device__(Entry const& l, Entry const& r) {
165+
if (l.index == r.index) {
166+
if (IsCat(d_feature_types, l.index)) {
167+
return l.fvalue == r.fvalue;
168+
}
169+
}
170+
return false;
171+
});
169172
}
170173
sorted_entries.resize(n_uniques);
171174

@@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
189192
}
190193
});
191194
// Turn size into ptr.
192-
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
195+
thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
193196
d_cuts_ptr.data());
194197
}
195198
} // namespace detail
@@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
225228
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
226229
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
227230
});
228-
detail::SortByWeight(&entry_weight, &sorted_entries);
231+
detail::SortByWeight(ctx, &entry_weight, &sorted_entries);
229232
} else {
230233
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
231234
detail::EntryCompareOp());
@@ -238,21 +241,21 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
238241
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
239242
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
240243
});
241-
detail::GetColumnSizesScan(ctx->Device(), info.num_col_, num_cuts_per_feature,
244+
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
242245
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
243246
&column_sizes_scan);
244247
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
245248
if (sketch_container->HasCategorical()) {
246249
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
247-
detail::RemoveDuplicatedCategories(ctx->Device(), info, d_cuts_ptr, &sorted_entries, p_weight,
250+
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
248251
&column_sizes_scan);
249252
}
250253

251254
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
252255
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
253256

254257
// Add cuts into sketches
255-
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
258+
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
256259
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
257260

258261
sorted_entries.clear();

0 commit comments

Comments (0)