1
1
/* !
2
- * Copyright by Contributors 2017-2023
2
+ * Copyright by Contributors 2017-2024
3
3
*/
4
- #pragma GCC diagnostic push
5
- #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
6
- #pragma GCC diagnostic ignored "-W#pragma-messages"
7
- #pragma GCC diagnostic pop
4
+ #include < dmlc/timer.h>
5
+ // #pragma GCC diagnostic push
6
+ // #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
7
+ // #pragma GCC diagnostic ignored "-W#pragma-messages"
8
+ // #pragma GCC diagnostic pop
8
9
9
10
#include < cstddef>
10
11
#include < limits>
@@ -158,6 +159,8 @@ float GetLeafWeight(const Node* nodes, const float* fval_buff) {
158
159
159
160
template <bool any_missing>
160
161
void DevicePredictInternal (::sycl::queue* qu,
162
+ USMVector<float , MemoryType::on_device>* fval_buff,
163
+ USMVector<uint8_t , MemoryType::on_device>* miss_buff,
161
164
const sycl::DeviceMatrix& dmat,
162
165
HostDeviceVector<float >* out_preds,
163
166
const gbm::GBTreeModel& model,
@@ -178,15 +181,17 @@ void DevicePredictInternal(::sycl::queue* qu,
178
181
int num_rows = dmat.row_ptr .Size () - 1 ;
179
182
int num_group = model.learner_model_param ->num_output_group ;
180
183
181
- USMVector<float , MemoryType::on_device> fval_buff (qu, num_features * num_rows);
182
- USMVector<uint8_t , MemoryType::on_device> miss_buff;
183
- auto * fval_buff_ptr = fval_buff.Data ();
184
+ bool update_buffs = !dmat.is_from_cache ;
184
185
185
186
std::vector<::sycl::event> events (1 );
186
- if constexpr (any_missing) {
187
- miss_buff.Resize (qu, num_features * num_rows, 1 , &events[0 ]);
187
+ if (update_buffs) {
188
+ fval_buff->Resize (qu, num_features * num_rows);
189
+ if constexpr (any_missing) {
190
+ miss_buff->Resize (qu, num_features * num_rows, 1 , &events[0 ]);
191
+ }
188
192
}
189
- auto * miss_buff_ptr = miss_buff.Data ();
193
+ auto * fval_buff_ptr = fval_buff->Data ();
194
+ auto * miss_buff_ptr = miss_buff->Data ();
190
195
191
196
auto & out_preds_vec = out_preds->HostVector ();
192
197
::sycl::buffer<float , 1 > out_preds_buf (out_preds_vec.data (), out_preds_vec.size ());
@@ -198,12 +203,14 @@ void DevicePredictInternal(::sycl::queue* qu,
198
203
auto * fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
199
204
auto * miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;
200
205
201
- const Entry* first_entry = data + row_ptr[row_idx];
202
- const Entry* last_entry = data + row_ptr[row_idx + 1 ];
203
- for (const Entry* entry = first_entry; entry < last_entry; entry += 1 ) {
204
- fval_buff_row_ptr[entry->index ] = entry->fvalue ;
205
- if constexpr (any_missing) {
206
- miss_buff_row_ptr[entry->index ] = 0 ;
206
+ if (update_buffs) {
207
+ const Entry* first_entry = data + row_ptr[row_idx];
208
+ const Entry* last_entry = data + row_ptr[row_idx + 1 ];
209
+ for (const Entry* entry = first_entry; entry < last_entry; entry += 1 ) {
210
+ fval_buff_row_ptr[entry->index ] = entry->fvalue ;
211
+ if constexpr (any_missing) {
212
+ miss_buff_row_ptr[entry->index ] = 0 ;
213
+ }
207
214
}
208
215
}
209
216
@@ -241,6 +248,7 @@ class Predictor : public xgboost::Predictor {
241
248
void InitOutPredictions (const MetaInfo& info,
242
249
HostDeviceVector<bst_float>* out_preds,
243
250
const gbm::GBTreeModel& model) const override {
251
+ predictor_monitor_.Start (" InitOutPredictions" );
244
252
CHECK_NE (model.learner_model_param ->num_output_group , 0 );
245
253
size_t n = model.learner_model_param ->num_output_group * info.num_row_ ;
246
254
const auto & base_margin = info.base_margin_ .Data ()->HostVector ();
@@ -268,33 +276,40 @@ class Predictor : public xgboost::Predictor {
268
276
}
269
277
std::fill (out_preds_h.begin (), out_preds_h.end (), base_score);
270
278
}
279
+ predictor_monitor_.Stop (" InitOutPredictions" );
271
280
}
272
281
273
282
explicit Predictor (Context const * context) :
274
283
xgboost::Predictor::Predictor{context},
275
- cpu_predictor (xgboost::Predictor::Create(" cpu_predictor" , context)) {}
284
+ cpu_predictor (xgboost::Predictor::Create(" cpu_predictor" , context)) {
285
+ predictor_monitor_.Init (" SyclPredictor" );
286
+ }
276
287
277
288
void PredictBatch (DMatrix *dmat, PredictionCacheEntry *predts,
278
289
const gbm::GBTreeModel &model, uint32_t tree_begin,
279
- uint32_t tree_end = 0 ) const override {
290
+ uint32_t tree_end = 0 , bool training = false ) const override {
280
291
::sycl::queue qu = device_manager.GetQueue (ctx_->Device ());
281
- // TODO(razdoburdin): remove temporary workaround after cache fix
282
- sycl::DeviceMatrix device_matrix;
283
- device_matrix. Init (qu, dmat );
292
+ predictor_monitor_. Start ( " InitDeviceMatrix " );
293
+ device_matrix. Init (qu, dmat, training) ;
294
+ predictor_monitor_. Stop ( " InitDeviceMatrix " );
284
295
285
296
auto * out_preds = &predts->predictions ;
286
297
if (tree_end == 0 ) {
287
298
tree_end = model.trees .size ();
288
299
}
289
300
301
+ predictor_monitor_.Start (" DevicePredictInternal" );
290
302
if (tree_begin < tree_end) {
291
303
const bool any_missing = !(dmat->IsDense ());
292
304
if (any_missing) {
293
- DevicePredictInternal<true >(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
305
+ DevicePredictInternal<true >(&qu, &fval_buff, &miss_buff, device_matrix,
306
+ out_preds, model, tree_begin, tree_end);
294
307
} else {
295
- DevicePredictInternal<false >(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
308
+ DevicePredictInternal<false >(&qu, &fval_buff, &miss_buff, device_matrix,
309
+ out_preds, model, tree_begin, tree_end);
296
310
}
297
311
}
312
+ predictor_monitor_.Stop (" DevicePredictInternal" );
298
313
}
299
314
300
315
bool InplacePredict (std::shared_ptr<DMatrix> p_m,
@@ -341,7 +356,11 @@ class Predictor : public xgboost::Predictor {
341
356
342
357
private:
343
358
DeviceManager device_manager;
359
+ mutable sycl::DeviceMatrix device_matrix;
360
+ mutable USMVector<float , MemoryType::on_device> fval_buff;
361
+ mutable USMVector<uint8_t , MemoryType::on_device> miss_buff;
344
362
363
+ mutable xgboost::common::Monitor predictor_monitor_;
345
364
std::unique_ptr<xgboost::Predictor> cpu_predictor;
346
365
};
347
366
0 commit comments