Skip to content

Commit 7faf800

Browse files
author
Dmitry Razdoburdin
committed
fp32 fix
1 parent 571696c commit 7faf800

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

plugin/sycl/context_helper.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ namespace sycl {
1414

1515
DeviceOrd DeviceFP64(const DeviceOrd& device) {
1616
DeviceManager device_manager;
17-
bool support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
17+
bool support_fp64 = true;
18+
if (device.IsSycl()) {
19+
support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
20+
}
21+
1822
if (support_fp64) {
1923
return device;
2024
} else {

src/objective/multiclass_obj.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ class SoftmaxMultiClassObj : public ObjFunction {
113113
auto weights = common::MakeOptionalWeights(device, info.weights_);
114114

115115
preds.SetDevice(device);
116-
auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes);
116+
Context cpu_context = Context();
117+
auto predt = linalg::MakeTensorView(
118+
device == ctx_->Device() ? this->ctx_ : &cpu_context,
119+
&preds, n_samples, n_classes);
117120
CHECK_EQ(labels.Shape(1), 1);
118121
auto y1d = labels.Slice(linalg::All(), 0);
119122
CHECK_EQ(y1d.Shape(0), info.num_row_);

0 commit comments

Comments
 (0)