Skip to content

Commit 14171e8

Browse files
authored
fix fp32 (#79)
Co-authored-by: Dmitry Razdoburdin <>
1 parent a15b93c commit 14171e8

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

plugin/sycl/common/linalg_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ bool Validate(DeviceOrd device, TensorView<T, D> t, Fn&& fn) {
103103
namespace linalg {
104104
template <typename T, int32_t D, typename Fn>
105105
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
106-
if (ctx->IsSycl()) {
106+
if (t.Device().IsSycl()) {
107107
sycl::linalg::ElementWiseKernel(t, fn);
108108
} else {
109109
ElementWiseKernelHost(t, ctx->Threads(), fn);

src/objective/multiclass_obj.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
110110
<< "Number of weights should be equal to number of data points.";
111111
}
112112
info.weights_.SetDevice(device);
113-
auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_);
113+
auto weights = common::MakeOptionalWeights(device, info.weights_);
114114

115115
preds.SetDevice(device);
116116
auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes);

0 commit comments

Comments
 (0)