Skip to content

Commit f6bce94

Browse files
authored
[backport][sycl] Fix set device (dmlc#11712) (dmlc#11713)
1 parent 718066b commit f6bce94

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

plugin/sycl/data/gradient_index.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
6161
BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());
6262

6363
for (auto &batch : dmat->GetBatches<SparsePage>()) {
64+
batch.data.SetDevice(ctx->Device());
65+
batch.offset.SetDevice(ctx->Device());
6466
const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
6567
const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
6668
size_t batch_size = batch.Size();

plugin/sycl/predictor/predictor.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ class Predictor : public xgboost::Predictor {
206206
xgboost::Predictor::Predictor{context},
207207
cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)),
208208
qu_(device_manager.GetQueue(context->Device())),
209-
device_prop_(qu_->get_device()) {}
209+
device_prop_(qu_->get_device()) {
210+
device_model.SetDevice(context->Device());
211+
}
210212

211213
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
212214
const gbm::GBTreeModel &model, bst_tree_t tree_begin,

0 commit comments

Comments
 (0)