Skip to content

Commit 571696c

Browse files
authored
fix VecScaMul for fp32 (#80)
Co-authored-by: Dmitry Razdoburdin <>
1 parent 14171e8 commit 571696c

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

plugin/sycl/common/linalg_op.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const>
3131
}).wait();
3232
}
3333

34-
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
35-
sycl::DeviceManager device_manager;
36-
auto* qu = device_manager.GetQueue(ctx->Device());
37-
34+
template <typename T>
35+
void VecScaMul(::sycl::queue* qu, xgboost::linalg::VectorView<float> x, T mul) {
3836
qu->submit([&](::sycl::handler& cgh) {
3937
cgh.parallel_for<>(::sycl::range<1>(x.Size()),
4038
[=](::sycl::id<1> pid) {
@@ -47,6 +45,15 @@ void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double
4745

4846
namespace xgboost::linalg::sycl_impl {
4947
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
50-
xgboost::sycl::linalg::VecScaMul(ctx, x, mul);
48+
sycl::DeviceManager device_manager;
49+
auto* qu = device_manager.GetQueue(ctx->Device());
50+
51+
bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
52+
if (has_fp64_support) {
53+
xgboost::sycl::linalg::VecScaMul(qu, x, mul);
54+
} else {
55+
float mul_fp32 = mul;
56+
xgboost::sycl::linalg::VecScaMul(qu, x, mul_fp32);
57+
}
5158
}
5259
} // namespace xgboost::linalg::sycl_impl

0 commit comments

Comments
 (0)