1
1
/* *
2
- * Copyright 2023 by XGBoost Contributors
2
+ * Copyright 2023-2025, XGBoost Contributors
3
3
*/
4
4
#include " xgboost/multi_target_tree_model.h"
5
5
6
- #include < algorithm> // for copy_n
7
- #include < cstddef> // for size_t
8
- #include < cstdint> // for int32_t, uint8_t
9
- #include < limits> // for numeric_limits
10
- #include < string_view> // for string_view
11
- #include < utility> // for move
12
- #include < vector> // for vector
13
-
14
- #include " io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15
- #include " xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16
- #include " xgboost/json.h" // for Json, get, Object, Number, Integer, ...
6
+ #include < algorithm> // for copy_n
7
+ #include < cstddef> // for size_t
8
+ #include < cstdint> // for int32_t, uint8_t
9
+ #include < limits> // for numeric_limits
10
+ #include < string_view> // for string_view
11
+ #include < utility> // for move
12
+ #include < vector> // for vector
13
+
14
+ #include " io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15
+ #include " xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16
+ #include " xgboost/json.h" // for Json, get, Object, Number, Integer, ...
17
17
#include " xgboost/logging.h"
18
18
#include " xgboost/tree_model.h" // for TreeParam
19
19
@@ -30,27 +30,47 @@ MultiTargetTree::MultiTargetTree(TreeParam const* param)
30
30
CHECK_GT (param_->size_leaf_vector , 1 );
31
31
}
32
32
33
// Copy constructor: builds this tree as an element-wise copy of `that`.
//
// Works in two phases:
//   1. The member-initializer list copies `param_` as-is and constructs every
//      array member with the same Size() and Device() as its counterpart in
//      `that`, using 0 as the initial fill value.
//   2. The body then calls Copy() on each array member to pull the actual
//      contents over from `that`.
//
// NOTE(review): `param_` is dereferenced with `->` elsewhere in this file, so
// the copy presumably refers to the same TreeParam object as `that` rather
// than to a fresh one -- confirm the owner outlives both trees.
+ MultiTargetTree::MultiTargetTree (MultiTargetTree const & that)
34
+ : param_{that.param_ },
35
+ left_ (that.left_.Size(), 0 , that.left_.Device()),
36
+ right_ (that.right_.Size(), 0 , that.right_.Device()),
37
+ parent_ (that.parent_.Size(), 0 , that.parent_.Device()),
38
+ split_index_ (that.split_index_.Size(), 0 , that.split_index_.Device()),
39
+ default_left_ (that.default_left_.Size(), 0 , that.default_left_.Device()),
40
+ split_conds_ (that.split_conds_.Size(), 0 , that.split_conds_.Device()),
41
+ weights_ (that.weights_.Size(), 0 , that.weights_.Device()) {
42
// Every destination was pre-sized above to match its source before copying.
// NOTE(review): presumably Copy() expects equal sizes on both sides -- confirm
// against the HostDeviceVector API.
+ this ->left_ .Copy (that.left_ );
43
+ this ->right_ .Copy (that.right_ );
44
+ this ->parent_ .Copy (that.parent_ );
45
+ this ->split_index_ .Copy (that.split_index_ );
46
+ this ->default_left_ .Copy (that.default_left_ );
47
+ this ->split_conds_ .Copy (that.split_conds_ );
48
+ this ->weights_ .Copy (that.weights_ );
49
+ }
50
+
33
51
template <bool typed, bool feature_is_64>
34
- void LoadModelImpl (Json const & in, std::vector<float >* p_weights, std::vector<bst_node_t >* p_lefts,
35
- std::vector<bst_node_t >* p_rights, std::vector<bst_node_t >* p_parents,
36
- std::vector<float >* p_conds, std::vector<bst_feature_t >* p_fidx,
37
- std::vector<std::uint8_t >* p_dft_left) {
52
+ void LoadModelImpl (Json const & in, HostDeviceVector<float >* p_weights,
53
+ HostDeviceVector<bst_node_t >* p_lefts, HostDeviceVector<bst_node_t >* p_rights,
54
+ HostDeviceVector<bst_node_t >* p_parents, HostDeviceVector<float >* p_conds,
55
+ HostDeviceVector<bst_feature_t >* p_fidx,
56
+ HostDeviceVector<std::uint8_t >* p_dft_left) {
38
57
namespace tf = tree_field;
39
58
40
- auto get_float = [&](std::string_view name, std::vector <float >* p_out) {
59
+ auto get_float = [&](std::string_view name, HostDeviceVector <float >* p_out) {
41
60
auto & values = get<FloatArrayT<typed>>(get<Object const >(in).find (name)->second );
42
61
auto & out = *p_out;
43
- out.resize (values.size ());
62
+ out.Resize (values.size ());
63
+ auto & h_out = out.HostVector ();
44
64
for (std::size_t i = 0 ; i < values.size (); ++i) {
45
- out [i] = GetElem<Number>(values, i);
65
+ h_out [i] = GetElem<Number>(values, i);
46
66
}
47
67
};
48
68
get_float (tf::kBaseWeight , p_weights);
49
69
get_float (tf::kSplitCond , p_conds);
50
70
51
- auto get_nidx = [&](std::string_view name, std::vector <bst_node_t >* p_nidx) {
71
+ auto get_nidx = [&](std::string_view name, HostDeviceVector <bst_node_t >* p_nidx) {
52
72
auto & nidx = get<I32ArrayT<typed>>(get<Object const >(in).find (name)->second );
53
- auto & out_nidx = * p_nidx;
73
+ auto & out_nidx = p_nidx-> HostVector () ;
54
74
out_nidx.resize (nidx.size ());
55
75
for (std::size_t i = 0 ; i < nidx.size (); ++i) {
56
76
out_nidx[i] = GetElem<Integer>(nidx, i);
@@ -61,15 +81,15 @@ void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bs
61
81
get_nidx (tf::kParent , p_parents);
62
82
63
83
auto const & splits = get<IndexArrayT<typed, feature_is_64> const >(in[tf::kSplitIdx ]);
64
- p_fidx->resize (splits.size ());
65
- auto & out_fidx = * p_fidx;
84
+ p_fidx->Resize (splits.size ());
85
+ auto & out_fidx = p_fidx-> HostVector () ;
66
86
for (std::size_t i = 0 ; i < splits.size (); ++i) {
67
87
out_fidx[i] = GetElem<Integer>(splits, i);
68
88
}
69
89
70
90
auto const & dft_left = get<U8ArrayT<typed> const >(in[tf::kDftLeft ]);
71
- auto & out_dft_l = * p_dft_left;
72
- out_dft_l. resize (dft_left. size () );
91
+ p_dft_left-> Resize (dft_left. size ()) ;
92
+ auto & out_dft_l = p_dft_left-> HostVector ( );
73
93
for (std::size_t i = 0 ; i < dft_left.size (); ++i) {
74
94
out_dft_l[i] = GetElem<Boolean>(dft_left, i);
75
95
}
@@ -109,19 +129,25 @@ void MultiTargetTree::SaveModel(Json* p_out) const {
109
129
U8Array default_left (n_nodes);
110
130
F32Array weights (n_nodes * this ->NumTarget ());
111
131
132
+ auto const & h_left = this ->left_ .ConstHostVector ();
133
+ auto const & h_right = this ->right_ .ConstHostVector ();
134
+ auto const & h_parent = this ->parent_ .ConstHostVector ();
135
+ auto const & h_split_index = this ->split_index_ .ConstHostVector ();
136
+ auto const & h_split_conds = this ->split_conds_ .ConstHostVector ();
137
+ auto const & h_default_left = this ->default_left_ .ConstHostVector ();
112
138
auto save_tree = [&](auto * p_indices_array) {
113
139
auto & indices_array = *p_indices_array;
114
140
for (bst_node_t nidx = 0 ; nidx < n_nodes; ++nidx) {
115
- CHECK_LT (nidx, left_.size ());
116
- lefts.Set (nidx, left_ [nidx]);
117
- CHECK_LT (nidx, right_.size ());
118
- rights.Set (nidx, right_ [nidx]);
119
- CHECK_LT (nidx, parent_.size ());
120
- parents.Set (nidx, parent_ [nidx]);
121
- CHECK_LT (nidx, split_index_.size ());
122
- indices_array.Set (nidx, split_index_ [nidx]);
123
- conds.Set (nidx, split_conds_ [nidx]);
124
- default_left.Set (nidx, default_left_ [nidx]);
141
+ CHECK_LT (nidx, left_.Size ());
142
+ lefts.Set (nidx, h_left [nidx]);
143
+ CHECK_LT (nidx, right_.Size ());
144
+ rights.Set (nidx, h_right [nidx]);
145
+ CHECK_LT (nidx, parent_.Size ());
146
+ parents.Set (nidx, h_parent [nidx]);
147
+ CHECK_LT (nidx, split_index_.Size ());
148
+ indices_array.Set (nidx, h_split_index [nidx]);
149
+ conds.Set (nidx, h_split_conds [nidx]);
150
+ default_left.Set (nidx, h_default_left [nidx]);
125
151
126
152
auto in_weight = this ->NodeWeight (nidx);
127
153
auto weight_out = common::Span<float >(weights.GetArray ())
@@ -157,8 +183,8 @@ void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> w
157
183
CHECK (this ->IsLeaf (nidx)) << " Collapsing a split node to leaf " << MTNotImplemented ();
158
184
auto const next_nidx = nidx + 1 ;
159
185
CHECK_EQ (weight.Size (), this ->NumTarget ());
160
- CHECK_GE (weights_.size (), next_nidx * weight.Size ());
161
- auto out_weight = common::Span< float >( weights_).subspan (nidx * weight.Size (), weight.Size ());
186
+ CHECK_GE (weights_.Size (), next_nidx * weight.Size ());
187
+ auto out_weight = weights_. HostSpan ( ).subspan (nidx * weight.Size (), weight.Size ());
162
188
for (std::size_t i = 0 ; i < weight.Size (); ++i) {
163
189
out_weight[i] = weight (i);
164
190
}
@@ -169,39 +195,40 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
169
195
linalg::VectorView<float const > left_weight,
170
196
linalg::VectorView<float const > right_weight) {
171
197
CHECK (this ->IsLeaf (nidx));
172
- CHECK_GE (parent_.size (), 1 );
173
- CHECK_EQ (parent_.size (), left_.size ());
174
- CHECK_EQ (left_.size (), right_.size ());
198
+ CHECK_GE (parent_.Size (), 1 );
199
+ CHECK_EQ (parent_.Size (), left_.Size ());
200
+ CHECK_EQ (left_.Size (), right_.Size ());
175
201
176
202
std::size_t n = param_->num_nodes + 2 ;
177
203
CHECK_LT (split_idx, this ->param_ ->num_feature );
178
- left_.resize (n, InvalidNodeId ());
179
- right_.resize (n, InvalidNodeId ());
180
- parent_.resize (n, InvalidNodeId ());
204
+ left_.Resize (n, InvalidNodeId ());
205
+ right_.Resize (n, InvalidNodeId ());
206
+ parent_.Resize (n, InvalidNodeId ());
181
207
182
- auto left_child = parent_.size () - 2 ;
183
- auto right_child = parent_.size () - 1 ;
208
+ auto left_child = parent_.Size () - 2 ;
209
+ auto right_child = parent_.Size () - 1 ;
184
210
185
- left_[nidx] = left_child;
186
- right_[nidx] = right_child;
211
+ left_. HostVector () [nidx] = left_child;
212
+ right_. HostVector () [nidx] = right_child;
187
213
214
+ auto & h_parent = parent_.HostVector ();
188
215
if (nidx != 0 ) {
189
- CHECK_NE (parent_ [nidx], InvalidNodeId ());
216
+ CHECK_NE (h_parent [nidx], InvalidNodeId ());
190
217
}
191
218
192
- parent_ [left_child] = nidx;
193
- parent_ [right_child] = nidx;
219
+ h_parent [left_child] = nidx;
220
+ h_parent [right_child] = nidx;
194
221
195
- split_index_.resize (n);
196
- split_index_[nidx] = split_idx;
222
+ split_index_.Resize (n);
223
+ split_index_. HostVector () [nidx] = split_idx;
197
224
198
- split_conds_.resize (n, std::numeric_limits<float >::quiet_NaN ());
199
- split_conds_[nidx] = split_cond;
225
+ split_conds_.Resize (n, std::numeric_limits<float >::quiet_NaN ());
226
+ split_conds_. HostVector () [nidx] = split_cond;
200
227
201
- default_left_.resize (n);
202
- default_left_[nidx] = static_cast <std::uint8_t >(default_left);
228
+ default_left_.Resize (n);
229
+ default_left_. HostVector () [nidx] = static_cast <std::uint8_t >(default_left);
203
230
204
- weights_.resize (n * this ->NumTarget ());
231
+ weights_.Resize (n * this ->NumTarget ());
205
232
auto p_weight = this ->NodeWeight (nidx);
206
233
CHECK_EQ (p_weight.Size (), base_weight.Size ());
207
234
auto l_weight = this ->NodeWeight (left_child);
@@ -217,5 +244,5 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
217
244
}
218
245
219
246
// Returns the number of targets, read directly from `param_->size_leaf_vector`.
bst_target_t MultiTargetTree::NumTarget () const { return param_->size_leaf_vector ; }
220
// Returns the size of this tree, defined as the number of entries in `parent_`.
// The removed (-) and added (+) lines differ only in the accessor spelling:
// `size()` before, `Size()` after.
- std::size_t MultiTargetTree::Size () const { return parent_.size (); }
247
+ std::size_t MultiTargetTree::Size () const { return parent_.Size (); }
221
248
} // namespace xgboost
0 commit comments