Skip to content

Commit 4a28128

Browse files
authored
[MT] Add device storage to multi-target tree. (dmlc#11277)
1 parent 76187c0 commit 4a28128

File tree

2 files changed: 135 additions and 87 deletions.

include/xgboost/multi_target_tree_model.h

Lines changed: 50 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,72 +1,93 @@
11
/**
2-
* Copyright 2023 by XGBoost contributors
2+
* Copyright 2023-2025, XGBoost contributors
33
*
4-
* \brief Core data structure for multi-target trees.
4+
* @brief Core data structure for multi-target trees.
55
*/
66
#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
77
#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8-
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
9-
#include <xgboost/context.h> // for Context
10-
#include <xgboost/linalg.h> // for VectorView
11-
#include <xgboost/model.h> // for Model
12-
#include <xgboost/span.h> // for Span
138

14-
#include <cinttypes> // for uint8_t
15-
#include <cstddef> // for size_t
16-
#include <vector> // for vector
9+
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
10+
#include <xgboost/context.h> // for Context
11+
#include <xgboost/host_device_vector.h> // for HostDeviceVector
12+
#include <xgboost/linalg.h> // for VectorView
13+
#include <xgboost/model.h> // for Model
14+
#include <xgboost/span.h> // for Span
15+
16+
#include <cstddef> // for size_t
17+
#include <cstdint> // for uint8_t
18+
#include <vector> // for vector
1719

1820
namespace xgboost {
1921
struct TreeParam;
2022
/**
21-
* \brief Tree structure for multi-target model.
23+
* @brief Tree structure for multi-target model.
2224
*/
2325
class MultiTargetTree : public Model {
2426
public:
2527
static bst_node_t constexpr InvalidNodeId() { return -1; }
2628

2729
private:
2830
TreeParam const* param_;
29-
std::vector<bst_node_t> left_;
30-
std::vector<bst_node_t> right_;
31-
std::vector<bst_node_t> parent_;
32-
std::vector<bst_feature_t> split_index_;
33-
std::vector<std::uint8_t> default_left_;
34-
std::vector<float> split_conds_;
35-
std::vector<float> weights_;
31+
HostDeviceVector<bst_node_t> left_;
32+
HostDeviceVector<bst_node_t> right_;
33+
HostDeviceVector<bst_node_t> parent_;
34+
HostDeviceVector<bst_feature_t> split_index_;
35+
HostDeviceVector<std::uint8_t> default_left_;
36+
HostDeviceVector<float> split_conds_;
37+
HostDeviceVector<float> weights_;
3638

3739
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
3840
auto beg = nidx * this->NumTarget();
39-
auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
41+
auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTarget());
4042
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
4143
}
4244
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
4345
auto beg = nidx * this->NumTarget();
44-
auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
46+
auto v = this->weights_.HostSpan().subspan(beg, this->NumTarget());
4547
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
4648
}
4749

4850
public:
4951
explicit MultiTargetTree(TreeParam const* param);
52+
MultiTargetTree(MultiTargetTree const& that);
53+
MultiTargetTree& operator=(MultiTargetTree const& that) = delete;
54+
MultiTargetTree(MultiTargetTree&& that) = default;
55+
MultiTargetTree& operator=(MultiTargetTree&& that) = default;
56+
5057
/**
51-
* \brief Set the weight for a leaf.
58+
* @brief Set the weight for a leaf.
5259
*/
5360
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
5461
/**
55-
* \brief Expand a leaf into split node.
62+
* @brief Expand a leaf into split node.
5663
*/
5764
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
5865
linalg::VectorView<float const> base_weight,
5966
linalg::VectorView<float const> left_weight,
6067
linalg::VectorView<float const> right_weight);
6168

62-
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
63-
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
64-
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
65-
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
69+
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
70+
return left_.ConstHostVector()[nidx] == InvalidNodeId();
71+
}
72+
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
73+
return parent_.ConstHostVector().at(nidx);
74+
}
75+
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
76+
return left_.ConstHostVector().at(nidx);
77+
}
78+
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
79+
return right_.ConstHostVector().at(nidx);
80+
}
6681

67-
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
68-
[[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
69-
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
82+
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
83+
return split_index_.ConstHostVector()[nidx];
84+
}
85+
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
86+
return split_conds_.ConstHostVector()[nidx];
87+
}
88+
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
89+
return default_left_.ConstHostVector()[nidx];
90+
}
7091
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
7192
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
7293
}

src/tree/multi_target_tree_model.cc

Lines changed: 85 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,19 @@
11
/**
2-
* Copyright 2023 by XGBoost Contributors
2+
* Copyright 2023-2025, XGBoost Contributors
33
*/
44
#include "xgboost/multi_target_tree_model.h"
55

6-
#include <algorithm> // for copy_n
7-
#include <cstddef> // for size_t
8-
#include <cstdint> // for int32_t, uint8_t
9-
#include <limits> // for numeric_limits
10-
#include <string_view> // for string_view
11-
#include <utility> // for move
12-
#include <vector> // for vector
13-
14-
#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15-
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16-
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
6+
#include <algorithm> // for copy_n
7+
#include <cstddef> // for size_t
8+
#include <cstdint> // for int32_t, uint8_t
9+
#include <limits> // for numeric_limits
10+
#include <string_view> // for string_view
11+
#include <utility> // for move
12+
#include <vector> // for vector
13+
14+
#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15+
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16+
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
1717
#include "xgboost/logging.h"
1818
#include "xgboost/tree_model.h" // for TreeParam
1919

@@ -30,27 +30,47 @@ MultiTargetTree::MultiTargetTree(TreeParam const* param)
3030
CHECK_GT(param_->size_leaf_vector, 1);
3131
}
3232

33+
MultiTargetTree::MultiTargetTree(MultiTargetTree const& that)
34+
: param_{that.param_},
35+
left_(that.left_.Size(), 0, that.left_.Device()),
36+
right_(that.right_.Size(), 0, that.right_.Device()),
37+
parent_(that.parent_.Size(), 0, that.parent_.Device()),
38+
split_index_(that.split_index_.Size(), 0, that.split_index_.Device()),
39+
default_left_(that.default_left_.Size(), 0, that.default_left_.Device()),
40+
split_conds_(that.split_conds_.Size(), 0, that.split_conds_.Device()),
41+
weights_(that.weights_.Size(), 0, that.weights_.Device()) {
42+
this->left_.Copy(that.left_);
43+
this->right_.Copy(that.right_);
44+
this->parent_.Copy(that.parent_);
45+
this->split_index_.Copy(that.split_index_);
46+
this->default_left_.Copy(that.default_left_);
47+
this->split_conds_.Copy(that.split_conds_);
48+
this->weights_.Copy(that.weights_);
49+
}
50+
3351
template <bool typed, bool feature_is_64>
34-
void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bst_node_t>* p_lefts,
35-
std::vector<bst_node_t>* p_rights, std::vector<bst_node_t>* p_parents,
36-
std::vector<float>* p_conds, std::vector<bst_feature_t>* p_fidx,
37-
std::vector<std::uint8_t>* p_dft_left) {
52+
void LoadModelImpl(Json const& in, HostDeviceVector<float>* p_weights,
53+
HostDeviceVector<bst_node_t>* p_lefts, HostDeviceVector<bst_node_t>* p_rights,
54+
HostDeviceVector<bst_node_t>* p_parents, HostDeviceVector<float>* p_conds,
55+
HostDeviceVector<bst_feature_t>* p_fidx,
56+
HostDeviceVector<std::uint8_t>* p_dft_left) {
3857
namespace tf = tree_field;
3958

40-
auto get_float = [&](std::string_view name, std::vector<float>* p_out) {
59+
auto get_float = [&](std::string_view name, HostDeviceVector<float>* p_out) {
4160
auto& values = get<FloatArrayT<typed>>(get<Object const>(in).find(name)->second);
4261
auto& out = *p_out;
43-
out.resize(values.size());
62+
out.Resize(values.size());
63+
auto& h_out = out.HostVector();
4464
for (std::size_t i = 0; i < values.size(); ++i) {
45-
out[i] = GetElem<Number>(values, i);
65+
h_out[i] = GetElem<Number>(values, i);
4666
}
4767
};
4868
get_float(tf::kBaseWeight, p_weights);
4969
get_float(tf::kSplitCond, p_conds);
5070

51-
auto get_nidx = [&](std::string_view name, std::vector<bst_node_t>* p_nidx) {
71+
auto get_nidx = [&](std::string_view name, HostDeviceVector<bst_node_t>* p_nidx) {
5272
auto& nidx = get<I32ArrayT<typed>>(get<Object const>(in).find(name)->second);
53-
auto& out_nidx = *p_nidx;
73+
auto& out_nidx = p_nidx->HostVector();
5474
out_nidx.resize(nidx.size());
5575
for (std::size_t i = 0; i < nidx.size(); ++i) {
5676
out_nidx[i] = GetElem<Integer>(nidx, i);
@@ -61,15 +81,15 @@ void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bs
6181
get_nidx(tf::kParent, p_parents);
6282

6383
auto const& splits = get<IndexArrayT<typed, feature_is_64> const>(in[tf::kSplitIdx]);
64-
p_fidx->resize(splits.size());
65-
auto& out_fidx = *p_fidx;
84+
p_fidx->Resize(splits.size());
85+
auto& out_fidx = p_fidx->HostVector();
6686
for (std::size_t i = 0; i < splits.size(); ++i) {
6787
out_fidx[i] = GetElem<Integer>(splits, i);
6888
}
6989

7090
auto const& dft_left = get<U8ArrayT<typed> const>(in[tf::kDftLeft]);
71-
auto& out_dft_l = *p_dft_left;
72-
out_dft_l.resize(dft_left.size());
91+
p_dft_left->Resize(dft_left.size());
92+
auto& out_dft_l = p_dft_left->HostVector();
7393
for (std::size_t i = 0; i < dft_left.size(); ++i) {
7494
out_dft_l[i] = GetElem<Boolean>(dft_left, i);
7595
}
@@ -109,19 +129,25 @@ void MultiTargetTree::SaveModel(Json* p_out) const {
109129
U8Array default_left(n_nodes);
110130
F32Array weights(n_nodes * this->NumTarget());
111131

132+
auto const& h_left = this->left_.ConstHostVector();
133+
auto const& h_right = this->right_.ConstHostVector();
134+
auto const& h_parent = this->parent_.ConstHostVector();
135+
auto const& h_split_index = this->split_index_.ConstHostVector();
136+
auto const& h_split_conds = this->split_conds_.ConstHostVector();
137+
auto const& h_default_left = this->default_left_.ConstHostVector();
112138
auto save_tree = [&](auto* p_indices_array) {
113139
auto& indices_array = *p_indices_array;
114140
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
115-
CHECK_LT(nidx, left_.size());
116-
lefts.Set(nidx, left_[nidx]);
117-
CHECK_LT(nidx, right_.size());
118-
rights.Set(nidx, right_[nidx]);
119-
CHECK_LT(nidx, parent_.size());
120-
parents.Set(nidx, parent_[nidx]);
121-
CHECK_LT(nidx, split_index_.size());
122-
indices_array.Set(nidx, split_index_[nidx]);
123-
conds.Set(nidx, split_conds_[nidx]);
124-
default_left.Set(nidx, default_left_[nidx]);
141+
CHECK_LT(nidx, left_.Size());
142+
lefts.Set(nidx, h_left[nidx]);
143+
CHECK_LT(nidx, right_.Size());
144+
rights.Set(nidx, h_right[nidx]);
145+
CHECK_LT(nidx, parent_.Size());
146+
parents.Set(nidx, h_parent[nidx]);
147+
CHECK_LT(nidx, split_index_.Size());
148+
indices_array.Set(nidx, h_split_index[nidx]);
149+
conds.Set(nidx, h_split_conds[nidx]);
150+
default_left.Set(nidx, h_default_left[nidx]);
125151

126152
auto in_weight = this->NodeWeight(nidx);
127153
auto weight_out = common::Span<float>(weights.GetArray())
@@ -157,8 +183,8 @@ void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> w
157183
CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented();
158184
auto const next_nidx = nidx + 1;
159185
CHECK_EQ(weight.Size(), this->NumTarget());
160-
CHECK_GE(weights_.size(), next_nidx * weight.Size());
161-
auto out_weight = common::Span<float>(weights_).subspan(nidx * weight.Size(), weight.Size());
186+
CHECK_GE(weights_.Size(), next_nidx * weight.Size());
187+
auto out_weight = weights_.HostSpan().subspan(nidx * weight.Size(), weight.Size());
162188
for (std::size_t i = 0; i < weight.Size(); ++i) {
163189
out_weight[i] = weight(i);
164190
}
@@ -169,39 +195,40 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
169195
linalg::VectorView<float const> left_weight,
170196
linalg::VectorView<float const> right_weight) {
171197
CHECK(this->IsLeaf(nidx));
172-
CHECK_GE(parent_.size(), 1);
173-
CHECK_EQ(parent_.size(), left_.size());
174-
CHECK_EQ(left_.size(), right_.size());
198+
CHECK_GE(parent_.Size(), 1);
199+
CHECK_EQ(parent_.Size(), left_.Size());
200+
CHECK_EQ(left_.Size(), right_.Size());
175201

176202
std::size_t n = param_->num_nodes + 2;
177203
CHECK_LT(split_idx, this->param_->num_feature);
178-
left_.resize(n, InvalidNodeId());
179-
right_.resize(n, InvalidNodeId());
180-
parent_.resize(n, InvalidNodeId());
204+
left_.Resize(n, InvalidNodeId());
205+
right_.Resize(n, InvalidNodeId());
206+
parent_.Resize(n, InvalidNodeId());
181207

182-
auto left_child = parent_.size() - 2;
183-
auto right_child = parent_.size() - 1;
208+
auto left_child = parent_.Size() - 2;
209+
auto right_child = parent_.Size() - 1;
184210

185-
left_[nidx] = left_child;
186-
right_[nidx] = right_child;
211+
left_.HostVector()[nidx] = left_child;
212+
right_.HostVector()[nidx] = right_child;
187213

214+
auto& h_parent = parent_.HostVector();
188215
if (nidx != 0) {
189-
CHECK_NE(parent_[nidx], InvalidNodeId());
216+
CHECK_NE(h_parent[nidx], InvalidNodeId());
190217
}
191218

192-
parent_[left_child] = nidx;
193-
parent_[right_child] = nidx;
219+
h_parent[left_child] = nidx;
220+
h_parent[right_child] = nidx;
194221

195-
split_index_.resize(n);
196-
split_index_[nidx] = split_idx;
222+
split_index_.Resize(n);
223+
split_index_.HostVector()[nidx] = split_idx;
197224

198-
split_conds_.resize(n, std::numeric_limits<float>::quiet_NaN());
199-
split_conds_[nidx] = split_cond;
225+
split_conds_.Resize(n, std::numeric_limits<float>::quiet_NaN());
226+
split_conds_.HostVector()[nidx] = split_cond;
200227

201-
default_left_.resize(n);
202-
default_left_[nidx] = static_cast<std::uint8_t>(default_left);
228+
default_left_.Resize(n);
229+
default_left_.HostVector()[nidx] = static_cast<std::uint8_t>(default_left);
203230

204-
weights_.resize(n * this->NumTarget());
231+
weights_.Resize(n * this->NumTarget());
205232
auto p_weight = this->NodeWeight(nidx);
206233
CHECK_EQ(p_weight.Size(), base_weight.Size());
207234
auto l_weight = this->NodeWeight(left_child);
@@ -217,5 +244,5 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
217244
}
218245

219246
bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; }
220-
std::size_t MultiTargetTree::Size() const { return parent_.size(); }
247+
std::size_t MultiTargetTree::Size() const { return parent_.Size(); }
221248
} // namespace xgboost

0 commit comments

Comments (0)