diff --git a/plugin/sycl/context_helper.cc b/plugin/sycl/context_helper.cc index d5ced146187c..9ff77e5ca90f 100644 --- a/plugin/sycl/context_helper.cc +++ b/plugin/sycl/context_helper.cc @@ -14,7 +14,11 @@ namespace sycl { DeviceOrd DeviceFP64(const DeviceOrd& device) { DeviceManager device_manager; - bool support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64); + bool support_fp64 = true; + if (device.IsSycl()) { + support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64); + } + if (support_fp64) { return device; } else { diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 50dea22e8357..7e7f1c21a4ec 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -113,7 +113,10 @@ class SoftmaxMultiClassObj : public ObjFunction { auto weights = common::MakeOptionalWeights(device, info.weights_); preds.SetDevice(device); - auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes); + Context cpu_context = Context(); + auto predt = linalg::MakeTensorView( + device == ctx_->Device() ? this->ctx_ : &cpu_context, + &preds, n_samples, n_classes); CHECK_EQ(labels.Shape(1), 1); auto y1d = labels.Slice(linalg::All(), 0); CHECK_EQ(y1d.Shape(0), info.num_row_);