@@ -79,6 +79,162 @@ void HistUpdater<GradientSumT>::BuildLocalHistograms(
79
79
builder_monitor_.Stop (" BuildLocalHistograms" );
80
80
}
81
81
82
+ template <typename GradientSumT>
83
+ void HistUpdater<GradientSumT>::BuildNodeStats(
84
+ const common::GHistIndexMatrix &gmat,
85
+ RegTree *p_tree,
86
+ const USMVector<GradientPair, MemoryType::on_device> &gpair) {
87
+ builder_monitor_.Start (" BuildNodeStats" );
88
+ for (auto const & entry : qexpand_depth_wise_) {
89
+ int nid = entry.nid ;
90
+ this ->InitNewNode (nid, gmat, gpair, *p_tree);
91
+ // add constraints
92
+ if (!(*p_tree)[nid].IsLeftChild () && !(*p_tree)[nid].IsRoot ()) {
93
+ // it's a right child
94
+ auto parent_id = (*p_tree)[nid].Parent ();
95
+ auto left_sibling_id = (*p_tree)[parent_id].LeftChild ();
96
+ auto parent_split_feature_id = snode_host_[parent_id].best .SplitIndex ();
97
+ tree_evaluator_.AddSplit (
98
+ parent_id, left_sibling_id, nid, parent_split_feature_id,
99
+ snode_host_[left_sibling_id].weight , snode_host_[nid].weight );
100
+ interaction_constraints_.Split (parent_id, parent_split_feature_id,
101
+ left_sibling_id, nid);
102
+ }
103
+ }
104
+ builder_monitor_.Stop (" BuildNodeStats" );
105
+ }
106
+
107
+ template <typename GradientSumT>
108
+ void HistUpdater<GradientSumT>::AddSplitsToTree(
109
+ const common::GHistIndexMatrix &gmat,
110
+ RegTree *p_tree,
111
+ int *num_leaves,
112
+ int depth,
113
+ std::vector<ExpandEntry>* nodes_for_apply_split,
114
+ std::vector<ExpandEntry>* temp_qexpand_depth) {
115
+ builder_monitor_.Start (" AddSplitsToTree" );
116
+ auto evaluator = tree_evaluator_.GetEvaluator ();
117
+ for (auto const & entry : qexpand_depth_wise_) {
118
+ const auto lr = param_.learning_rate ;
119
+ int nid = entry.nid ;
120
+
121
+ if (snode_host_[nid].best .loss_chg < kRtEps ||
122
+ (param_.max_depth > 0 && depth == param_.max_depth ) ||
123
+ (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves )) {
124
+ (*p_tree)[nid].SetLeaf (snode_host_[nid].weight * lr);
125
+ } else {
126
+ nodes_for_apply_split->push_back (entry);
127
+
128
+ NodeEntry<GradientSumT>& e = snode_host_[nid];
129
+ bst_float left_leaf_weight =
130
+ evaluator.CalcWeight (nid, GradStats<GradientSumT>{e.best .left_sum }) * lr;
131
+ bst_float right_leaf_weight =
132
+ evaluator.CalcWeight (nid, GradStats<GradientSumT>{e.best .right_sum }) * lr;
133
+ p_tree->ExpandNode (nid, e.best .SplitIndex (), e.best .split_value ,
134
+ e.best .DefaultLeft (), e.weight , left_leaf_weight,
135
+ right_leaf_weight, e.best .loss_chg , e.stats .GetHess (),
136
+ e.best .left_sum .GetHess (), e.best .right_sum .GetHess ());
137
+
138
+ int left_id = (*p_tree)[nid].LeftChild ();
139
+ int right_id = (*p_tree)[nid].RightChild ();
140
+ temp_qexpand_depth->push_back (ExpandEntry (left_id, p_tree->GetDepth (left_id)));
141
+ temp_qexpand_depth->push_back (ExpandEntry (right_id, p_tree->GetDepth (right_id)));
142
+ // - 1 parent + 2 new children
143
+ (*num_leaves)++;
144
+ }
145
+ }
146
+ builder_monitor_.Stop (" AddSplitsToTree" );
147
+ }
148
+
149
+
150
+ template <typename GradientSumT>
151
+ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
152
+ const common::GHistIndexMatrix &gmat,
153
+ RegTree *p_tree,
154
+ int *num_leaves,
155
+ int depth,
156
+ std::vector<ExpandEntry> *temp_qexpand_depth) {
157
+ EvaluateSplits (qexpand_depth_wise_, gmat, *p_tree);
158
+
159
+ std::vector<ExpandEntry> nodes_for_apply_split;
160
+ AddSplitsToTree (gmat, p_tree, num_leaves, depth,
161
+ &nodes_for_apply_split, temp_qexpand_depth);
162
+ ApplySplit (nodes_for_apply_split, gmat, p_tree);
163
+ }
164
+
165
+ // Split nodes to 2 sets depending on amount of rows in each node
166
+ // Histograms for small nodes will be built explicitly
167
+ // Histograms for big nodes will be built by 'Subtraction Trick'
168
+ // Exception: in distributed setting, we always build the histogram for the left child node
169
+ // and use 'Subtraction Trick' to built the histogram for the right child node.
170
+ // This ensures that the workers operate on the same set of tree nodes.
171
+ template <typename GradientSumT>
172
+ void HistUpdater<GradientSumT>::SplitSiblings(
173
+ const std::vector<ExpandEntry> &nodes,
174
+ std::vector<ExpandEntry> *small_siblings,
175
+ std::vector<ExpandEntry> *big_siblings,
176
+ RegTree *p_tree) {
177
+ builder_monitor_.Start (" SplitSiblings" );
178
+ for (auto const & entry : nodes) {
179
+ int nid = entry.nid ;
180
+ RegTree::Node &node = (*p_tree)[nid];
181
+ if (node.IsRoot ()) {
182
+ small_siblings->push_back (entry);
183
+ } else {
184
+ const int32_t left_id = (*p_tree)[node.Parent ()].LeftChild ();
185
+ const int32_t right_id = (*p_tree)[node.Parent ()].RightChild ();
186
+
187
+ if (nid == left_id && row_set_collection_[left_id ].Size () <
188
+ row_set_collection_[right_id].Size ()) {
189
+ small_siblings->push_back (entry);
190
+ } else if (nid == right_id && row_set_collection_[right_id].Size () <=
191
+ row_set_collection_[left_id ].Size ()) {
192
+ small_siblings->push_back (entry);
193
+ } else {
194
+ big_siblings->push_back (entry);
195
+ }
196
+ }
197
+ }
198
+ builder_monitor_.Stop (" SplitSiblings" );
199
+ }
200
+
201
+ template <typename GradientSumT>
202
+ void HistUpdater<GradientSumT>::ExpandWithDepthWise(
203
+ const common::GHistIndexMatrix &gmat,
204
+ RegTree *p_tree,
205
+ const USMVector<GradientPair, MemoryType::on_device> &gpair) {
206
+ int num_leaves = 0 ;
207
+
208
+ // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
209
+ qexpand_depth_wise_.emplace_back (ExpandEntry::kRootNid ,
210
+ p_tree->GetDepth (ExpandEntry::kRootNid ));
211
+ ++num_leaves;
212
+ for (int depth = 0 ; depth < param_.max_depth + 1 ; depth++) {
213
+ std::vector<int > sync_ids;
214
+ std::vector<ExpandEntry> temp_qexpand_depth;
215
+ SplitSiblings (qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
216
+ &nodes_for_subtraction_trick_, p_tree);
217
+ hist_rows_adder_->AddHistRows (this , &sync_ids, p_tree);
218
+ BuildLocalHistograms (gmat, p_tree, gpair);
219
+ hist_synchronizer_->SyncHistograms (this , sync_ids, p_tree);
220
+ BuildNodeStats (gmat, p_tree, gpair);
221
+
222
+ EvaluateAndApplySplits (gmat, p_tree, &num_leaves, depth,
223
+ &temp_qexpand_depth);
224
+
225
+ // clean up
226
+ qexpand_depth_wise_.clear ();
227
+ nodes_for_subtraction_trick_.clear ();
228
+ nodes_for_explicit_hist_build_.clear ();
229
+ if (temp_qexpand_depth.empty ()) {
230
+ break ;
231
+ } else {
232
+ qexpand_depth_wise_ = temp_qexpand_depth;
233
+ temp_qexpand_depth.clear ();
234
+ }
235
+ }
236
+ }
237
+
82
238
template <typename GradientSumT>
83
239
void HistUpdater<GradientSumT>::ExpandWithLossGuide(
84
240
const common::GHistIndexMatrix& gmat,
@@ -326,7 +482,7 @@ void HistUpdater<GradientSumT>::InitData(
326
482
if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide ) {
327
483
qexpand_loss_guided_.reset (new ExpandQueue (LossGuide));
328
484
} else {
329
- LOG (WARNING) << " Depth-wise building is not yet implemented " ;
485
+ qexpand_depth_wise_. clear () ;
330
486
}
331
487
}
332
488
builder_monitor_.Stop (" InitData" );
0 commit comments