@@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
106
106
return std::min (sketch_batch_num_elements, kIntMax );
107
107
}
108
108
109
- void SortByWeight (dh::device_vector<float >* weights, dh::device_vector<Entry>* sorted_entries) {
109
+ void SortByWeight (Context const * ctx, dh::device_vector<float >* weights,
110
+ dh::device_vector<Entry>* sorted_entries) {
110
111
// Sort both entries and wegihts.
111
- dh::XGBDeviceAllocator< char > alloc ;
112
+ auto cuctx = ctx-> CUDACtx () ;
112
113
CHECK_EQ (weights->size (), sorted_entries->size ());
113
- thrust::sort_by_key (thrust::cuda::par (alloc ), sorted_entries->begin (), sorted_entries->end (),
114
- weights-> begin (), detail::EntryCompareOp ());
114
+ thrust::sort_by_key (cuctx-> TP ( ), sorted_entries->begin (), sorted_entries->end (), weights-> begin (),
115
+ detail::EntryCompareOp ());
115
116
116
117
// Scan weights
117
- dh::XGBCachingDeviceAllocator<char > caching;
118
118
thrust::inclusive_scan_by_key (
119
- thrust::cuda::par (caching ), sorted_entries->begin (), sorted_entries->end (), weights->begin (),
119
+ cuctx-> CTP ( ), sorted_entries->begin (), sorted_entries->end (), weights->begin (),
120
120
weights->begin (),
121
121
[=] __device__ (const Entry& a, const Entry& b) { return a.index == b.index ; });
122
122
}
123
123
124
- void RemoveDuplicatedCategories (DeviceOrd device, MetaInfo const & info, Span<bst_idx_t > d_cuts_ptr,
124
+ void RemoveDuplicatedCategories (Context const * ctx, MetaInfo const & info,
125
+ Span<bst_idx_t > d_cuts_ptr,
125
126
dh::device_vector<Entry>* p_sorted_entries,
126
127
dh::device_vector<float >* p_sorted_weights,
127
128
dh::caching_device_vector<size_t >* p_column_sizes_scan) {
128
- info.feature_types .SetDevice (device );
129
+ info.feature_types .SetDevice (ctx-> Device () );
129
130
auto d_feature_types = info.feature_types .ConstDeviceSpan ();
130
131
CHECK (!d_feature_types.empty ());
131
132
auto & column_sizes_scan = *p_column_sizes_scan;
@@ -142,30 +143,32 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
142
143
auto d_sorted_weights = dh::ToSpan (*p_sorted_weights);
143
144
auto val_in_it = thrust::make_zip_iterator (d_sorted_entries.data (), d_sorted_weights.data ());
144
145
auto val_out_it = thrust::make_zip_iterator (d_sorted_entries.data (), d_sorted_weights.data ());
145
- n_uniques = dh::SegmentedUnique (
146
- column_sizes_scan.data ().get (), column_sizes_scan.data ().get () + column_sizes_scan.size (),
147
- val_in_it, val_in_it + sorted_entries.size (), new_column_scan.data ().get (), val_out_it,
148
- [=] __device__ (Pair const & l, Pair const & r) {
149
- Entry const & le = thrust::get<0 >(l);
150
- Entry const & re = thrust::get<0 >(r);
151
- if (le.index == re.index && IsCat (d_feature_types, le.index )) {
152
- return le.fvalue == re.fvalue ;
153
- }
154
- return false ;
155
- });
146
+ n_uniques =
147
+ dh::SegmentedUnique (ctx->CUDACtx ()->CTP (), column_sizes_scan.data ().get (),
148
+ column_sizes_scan.data ().get () + column_sizes_scan.size (), val_in_it,
149
+ val_in_it + sorted_entries.size (), new_column_scan.data ().get (),
150
+ val_out_it, [=] __device__ (Pair const & l, Pair const & r) {
151
+ Entry const & le = thrust::get<0 >(l);
152
+ Entry const & re = thrust::get<0 >(r);
153
+ if (le.index == re.index && IsCat (d_feature_types, le.index )) {
154
+ return le.fvalue == re.fvalue ;
155
+ }
156
+ return false ;
157
+ });
156
158
p_sorted_weights->resize (n_uniques);
157
159
} else {
158
- n_uniques = dh::SegmentedUnique (
159
- column_sizes_scan.data ().get (), column_sizes_scan.data ().get () + column_sizes_scan.size (),
160
- sorted_entries.begin (), sorted_entries.end (), new_column_scan.data ().get (),
161
- sorted_entries.begin (), [=] __device__ (Entry const & l, Entry const & r) {
162
- if (l.index == r.index ) {
163
- if (IsCat (d_feature_types, l.index )) {
164
- return l.fvalue == r.fvalue ;
165
- }
166
- }
167
- return false ;
168
- });
160
+ n_uniques = dh::SegmentedUnique (ctx->CUDACtx ()->CTP (), column_sizes_scan.data ().get (),
161
+ column_sizes_scan.data ().get () + column_sizes_scan.size (),
162
+ sorted_entries.begin (), sorted_entries.end (),
163
+ new_column_scan.data ().get (), sorted_entries.begin (),
164
+ [=] __device__ (Entry const & l, Entry const & r) {
165
+ if (l.index == r.index ) {
166
+ if (IsCat (d_feature_types, l.index )) {
167
+ return l.fvalue == r.fvalue ;
168
+ }
169
+ }
170
+ return false ;
171
+ });
169
172
}
170
173
sorted_entries.resize (n_uniques);
171
174
@@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
189
192
}
190
193
});
191
194
// Turn size into ptr.
192
- thrust::exclusive_scan (thrust::device , new_cuts_size.cbegin (), new_cuts_size.cend (),
195
+ thrust::exclusive_scan (ctx-> CUDACtx ()-> CTP () , new_cuts_size.cbegin (), new_cuts_size.cend (),
193
196
d_cuts_ptr.data ());
194
197
}
195
198
} // namespace detail
@@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
225
228
std::size_t ridx = dh::SegmentId (row_ptrs, element_idx);
226
229
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
227
230
});
228
- detail::SortByWeight (&entry_weight, &sorted_entries);
231
+ detail::SortByWeight (ctx, &entry_weight, &sorted_entries);
229
232
} else {
230
233
thrust::sort (cuctx->TP (), sorted_entries.begin (), sorted_entries.end (),
231
234
detail::EntryCompareOp ());
@@ -238,21 +241,21 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
238
241
sorted_entries.data ().get (), [] __device__ (Entry const & e) -> data::COOTuple {
239
242
return {0 , e.index , e.fvalue }; // row_idx is not needed for scaning column size.
240
243
});
241
- detail::GetColumnSizesScan (ctx->Device (), info.num_col_ , num_cuts_per_feature,
244
+ detail::GetColumnSizesScan (ctx->CUDACtx (), ctx-> Device (), info.num_col_ , num_cuts_per_feature,
242
245
IterSpan{batch_it, sorted_entries.size ()}, dummy_is_valid, &cuts_ptr,
243
246
&column_sizes_scan);
244
247
auto d_cuts_ptr = cuts_ptr.DeviceSpan ();
245
248
if (sketch_container->HasCategorical ()) {
246
249
auto p_weight = entry_weight.empty () ? nullptr : &entry_weight;
247
- detail::RemoveDuplicatedCategories (ctx-> Device () , info, d_cuts_ptr, &sorted_entries, p_weight,
250
+ detail::RemoveDuplicatedCategories (ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
248
251
&column_sizes_scan);
249
252
}
250
253
251
254
auto const & h_cuts_ptr = cuts_ptr.ConstHostVector ();
252
255
CHECK_EQ (d_cuts_ptr.size (), column_sizes_scan.size ());
253
256
254
257
// Add cuts into sketches
255
- sketch_container->Push (dh::ToSpan (sorted_entries), dh::ToSpan (column_sizes_scan), d_cuts_ptr,
258
+ sketch_container->Push (ctx, dh::ToSpan (sorted_entries), dh::ToSpan (column_sizes_scan), d_cuts_ptr,
256
259
h_cuts_ptr.back (), dh::ToSpan (entry_weight));
257
260
258
261
sorted_entries.clear ();
0 commit comments