Skip to content

Commit cb02c22

Browse files
[SYCL]. Fixes for the upcoming v3.1 (dmlc#11725) (dmlc#11738)
--------- Co-authored-by: Dmitry Razdoburdin <[email protected]>
1 parent 74042bf commit cb02c22

26 files changed

+357
-123
lines changed

include/xgboost/linalg.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -591,13 +591,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data,
591591

592592
template <typename T, typename... S>
593593
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
594-
auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan();
594+
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
595595
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
596596
}
597597

598598
template <typename T, typename... S>
599599
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
600-
auto span = ctx->IsCUDA() ? data->ConstDeviceSpan() : data->ConstHostSpan();
600+
auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
601601
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
602602
}
603603

@@ -647,13 +647,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
647647

648648
template <typename T>
649649
auto MakeVec(HostDeviceVector<T> *data) {
650-
return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(),
650+
return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(),
651651
data->Size(), data->Device());
652652
}
653653

654654
template <typename T>
655655
auto MakeVec(HostDeviceVector<T> const *data) {
656-
return MakeVec(data->Device().IsCUDA() ? data->ConstDevicePointer() : data->ConstHostPointer(),
656+
return MakeVec(data->Device().IsCPU() ? data->ConstHostPointer() : data->ConstDevicePointer(),
657657
data->Size(), data->Device());
658658
}
659659

@@ -759,7 +759,7 @@ class Tensor {
759759
for (auto i = D; i < kDim; ++i) {
760760
shape_[i] = 1;
761761
}
762-
if (device.IsCUDA()) {
762+
if (!device.IsCPU()) {
763763
data_.SetDevice(device);
764764
data_.ConstDevicePointer(); // Pull to device;
765765
}
@@ -788,11 +788,11 @@ class Tensor {
788788
shape_[i] = 1;
789789
}
790790
auto size = detail::CalcSize(shape_);
791-
if (device.IsCUDA()) {
791+
if (!device.IsCPU()) {
792792
data_.SetDevice(device);
793793
}
794794
data_.Resize(size);
795-
if (device.IsCUDA()) {
795+
if (!device.IsCPU()) {
796796
data_.DevicePointer(); // Pull to device
797797
}
798798
}

plugin/sycl/common/host_device_vector.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "../device_manager.h"
1818
#include "../data.h"
19+
#include "../predictor/node.h"
1920

2021
namespace xgboost {
2122
template <typename T>
@@ -405,6 +406,7 @@ template class HostDeviceVector<FeatureType>;
405406
template class HostDeviceVector<Entry>;
406407
template class HostDeviceVector<bst_idx_t>;
407408
template class HostDeviceVector<uint32_t>; // bst_feature_t
409+
template class HostDeviceVector<sycl::predictor::Node>;
408410

409411
} // namespace xgboost
410412

plugin/sycl/common/linalg_op.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* Copyright 2021-2025, XGBoost Contributors
3+
* \file linalg_op.h
4+
*/
5+
6+
#include "../data.h"
7+
#include "../device_manager.h"
8+
9+
#include "../../../src/common/optional_weight.h" // for OptionalWeights
10+
#include "xgboost/context.h" // for Context
11+
12+
#include <sycl/sycl.hpp>
13+
14+
namespace xgboost::sycl::linalg {
15+
16+
void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const> indices,
17+
xgboost::common::OptionalWeights const& weights,
18+
xgboost::linalg::VectorView<float> bins) {
19+
sycl::DeviceManager device_manager;
20+
auto* qu = device_manager.GetQueue(ctx->Device());
21+
22+
qu->submit([&](::sycl::handler& cgh) {
23+
cgh.parallel_for<>(::sycl::range<1>(indices.Size()),
24+
[=](::sycl::id<1> pid) {
25+
const size_t i = pid[0];
26+
auto y = indices(i);
27+
auto w = weights[i];
28+
AtomicRef<float> bin_val(const_cast<float&>(bins(static_cast<std::size_t>(y))));
29+
bin_val += w;
30+
});
31+
}).wait();
32+
}
33+
34+
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
35+
sycl::DeviceManager device_manager;
36+
auto* qu = device_manager.GetQueue(ctx->Device());
37+
38+
qu->submit([&](::sycl::handler& cgh) {
39+
cgh.parallel_for<>(::sycl::range<1>(x.Size()),
40+
[=](::sycl::id<1> pid) {
41+
const size_t i = pid[0];
42+
const_cast<float&>(x(i)) *= mul;
43+
});
44+
}).wait();
45+
}
46+
} // namespace xgboost::sycl::linalg
47+
48+
namespace xgboost::linalg::sycl_impl {
49+
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
50+
xgboost::sycl::linalg::VecScaMul(ctx, x, mul);
51+
}
52+
} // namespace xgboost::linalg::sycl_impl
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*!
2+
* Copyright by Contributors 2017-2025
3+
*/
4+
#include <sycl/sycl.hpp>
5+
6+
#include "../../../src/common/optional_weight.h"
7+
8+
#include "../device_manager.h"
9+
10+
namespace xgboost::common::sycl_impl {
11+
double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
12+
sycl::DeviceManager device_manager;
13+
auto* qu = device_manager.GetQueue(ctx->Device());
14+
15+
const auto* data = weights.Data();
16+
double result = 0;
17+
{
18+
::sycl::buffer<double> buff(&result, 1);
19+
qu->submit([&](::sycl::handler& cgh) {
20+
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
21+
cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction,
22+
[=](::sycl::id<1> pid, auto& sum) {
23+
size_t i = pid[0];
24+
sum += data[i];
25+
});
26+
}).wait_and_throw();
27+
}
28+
29+
return result;
30+
}
31+
} // namespace xgboost::common::sycl_impl

plugin/sycl/device_properties.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class DeviceProperties {
4747
size_t l2_size = 0;
4848
float l2_size_per_eu = 0;
4949

50+
DeviceProperties():
51+
is_gpu(false) {}
52+
5053
explicit DeviceProperties(const ::sycl::device& device):
5154
is_gpu(device.is_gpu()),
5255
usm_host_allocations(device.has(::sycl::aspect::usm_host_allocations)),

plugin/sycl/predictor/node.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*!
2+
* Copyright by Contributors 2017-2025
3+
* \file node.h
4+
*/
5+
#ifndef PLUGIN_SYCL_PREDICTOR_NODE_H_
6+
#define PLUGIN_SYCL_PREDICTOR_NODE_H_
7+
8+
#include "../../src/gbm/gbtree_model.h"
9+
10+
namespace xgboost {
11+
namespace sycl {
12+
namespace predictor {
13+
14+
union NodeValue {
15+
float leaf_weight;
16+
float fvalue;
17+
};
18+
19+
class Node {
20+
int fidx;
21+
int left_child_idx;
22+
int right_child_idx;
23+
NodeValue val;
24+
25+
public:
26+
Node() = default;
27+
28+
explicit Node(const RegTree::Node& n) {
29+
left_child_idx = n.LeftChild();
30+
right_child_idx = n.RightChild();
31+
fidx = n.SplitIndex();
32+
if (n.DefaultLeft()) {
33+
fidx |= (1U << 31);
34+
}
35+
36+
if (n.IsLeaf()) {
37+
val.leaf_weight = n.LeafValue();
38+
} else {
39+
val.fvalue = n.SplitCond();
40+
}
41+
}
42+
43+
int LeftChildIdx() const {return left_child_idx; }
44+
45+
int RightChildIdx() const {return right_child_idx; }
46+
47+
bool IsLeaf() const { return left_child_idx == -1; }
48+
49+
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
50+
51+
bool MissingLeft() const { return (fidx >> 31) != 0; }
52+
53+
int MissingIdx() const {
54+
if (MissingLeft()) {
55+
return left_child_idx;
56+
} else {
57+
return right_child_idx;
58+
}
59+
}
60+
61+
float GetFvalue() const { return val.fvalue; }
62+
63+
float GetWeight() const { return val.leaf_weight; }
64+
};
65+
66+
} // namespace predictor
67+
} // namespace sycl
68+
} // namespace xgboost
69+
#endif // PLUGIN_SYCL_PREDICTOR_NODE_H_

0 commit comments

Comments
 (0)