diff --git a/plugin/sycl/common/optional_weight.cc b/plugin/sycl/common/optional_weight.cc index aa984a152dc3..819f274cec6d 100644 --- a/plugin/sycl/common/optional_weight.cc +++ b/plugin/sycl/common/optional_weight.cc @@ -8,14 +8,13 @@ #include "../device_manager.h" namespace xgboost::common::sycl_impl { -double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) { - sycl::DeviceManager device_manager; - auto* qu = device_manager.GetQueue(ctx->Device()); +template +T ElementWiseSum(::sycl::queue* qu, OptionalWeights const& weights) { const auto* data = weights.Data(); - double result = 0; + T result = 0; { - ::sycl::buffer buff(&result, 1); + ::sycl::buffer buff(&result, 1); qu->submit([&](::sycl::handler& cgh) { auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>()); cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction, @@ -28,4 +27,16 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) { return result; } + +double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) { + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(ctx->Device()); + + bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64); + if (has_fp64_support) { + return ElementWiseSum(qu, weights); + } else { + return ElementWiseSum(qu, weights); + } +} } // namespace xgboost::common::sycl_impl