diff --git a/plugin/sycl/common/linalg_op.cc b/plugin/sycl/common/linalg_op.cc index 55eca035ced8..3dc3c17cf17f 100644 --- a/plugin/sycl/common/linalg_op.cc +++ b/plugin/sycl/common/linalg_op.cc @@ -31,10 +31,8 @@ void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView }).wait(); } -void VecScaMul(Context const* ctx, xgboost::linalg::VectorView x, double mul) { - sycl::DeviceManager device_manager; - auto* qu = device_manager.GetQueue(ctx->Device()); - +template +void VecScaMul(::sycl::queue* qu, xgboost::linalg::VectorView x, T mul) { qu->submit([&](::sycl::handler& cgh) { cgh.parallel_for<>(::sycl::range<1>(x.Size()), [=](::sycl::id<1> pid) { @@ -47,6 +45,15 @@ void VecScaMul(Context const* ctx, xgboost::linalg::VectorView x, double namespace xgboost::linalg::sycl_impl { void VecScaMul(Context const* ctx, xgboost::linalg::VectorView x, double mul) { - xgboost::sycl::linalg::VecScaMul(ctx, x, mul); + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(ctx->Device()); + + bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64); + if (has_fp64_support) { + xgboost::sycl::linalg::VecScaMul(qu, x, mul); + } else { + float mul_fp32 = mul; + xgboost::sycl::linalg::VecScaMul(qu, x, mul_fp32); + } } } // namespace xgboost::linalg::sycl_impl