@@ -31,10 +31,8 @@ void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const>
3131 }).wait ();
3232}
3333
34- void VecScaMul (Context const * ctx, xgboost::linalg::VectorView<float > x, double mul) {
35- sycl::DeviceManager device_manager;
36- auto * qu = device_manager.GetQueue (ctx->Device ());
37-
34+ template <typename T>
35+ void VecScaMul (::sycl::queue* qu, xgboost::linalg::VectorView<float > x, T mul) {
3836 qu->submit ([&](::sycl::handler& cgh) {
3937 cgh.parallel_for <>(::sycl::range<1 >(x.Size ()),
4038 [=](::sycl::id<1 > pid) {
@@ -47,6 +45,15 @@ void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double
4745
4846namespace xgboost ::linalg::sycl_impl {
4947void VecScaMul (Context const * ctx, xgboost::linalg::VectorView<float > x, double mul) {
50- xgboost::sycl::linalg::VecScaMul (ctx, x, mul);
48+ sycl::DeviceManager device_manager;
49+ auto * qu = device_manager.GetQueue (ctx->Device ());
50+
51+ bool has_fp64_support = qu->get_device ().has (::sycl::aspect::fp64);
52+ if (has_fp64_support) {
53+ xgboost::sycl::linalg::VecScaMul (qu, x, mul);
54+ } else {
55+ float mul_fp32 = mul;
56+ xgboost::sycl::linalg::VecScaMul (qu, x, mul_fp32);
57+ }
5158}
5259} // namespace xgboost::linalg::sycl_impl
0 commit comments