@@ -236,12 +236,11 @@ class GBTree : public GradientBooster {
236
236
auto add_score = [&](auto fn) {
237
237
for (auto idx : trees) {
238
238
CHECK_LE (idx, total_n_trees) << " Invalid tree index." ;
239
- auto const & p_tree = model_.trees [idx];
240
- p_tree->WalkTree ([&](bst_node_t nidx) {
241
- auto const & node = (*p_tree)[nidx];
242
- if (!node.IsLeaf ()) {
243
- split_counts[node.SplitIndex ()]++;
244
- fn (p_tree, nidx, node.SplitIndex ());
239
+ auto const & tree = *model_.trees [idx];
240
+ tree.WalkTree ([&](bst_node_t nidx) {
241
+ if (!tree.IsLeaf (nidx)) {
242
+ split_counts[tree.SplitIndex (nidx)]++;
243
+ fn (tree, nidx, tree.SplitIndex (nidx));
245
244
}
246
245
return true ;
247
246
});
@@ -253,12 +252,18 @@ class GBTree : public GradientBooster {
253
252
gain_map[split] = split_counts[split];
254
253
});
255
254
} else if (importance_type == " gain" || importance_type == " total_gain" ) {
256
- add_score ([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
257
- gain_map[split] += p_tree->Stat (nidx).loss_chg ;
255
+ if (!model_.trees .empty () && model_.trees .front ()->IsMultiTarget ()) {
256
+ LOG (FATAL) << " gain/total_gain " << MTNotImplemented ();
257
+ }
258
+ add_score ([&](auto const & tree, bst_node_t nidx, bst_feature_t split) {
259
+ gain_map[split] += tree.Stat (nidx).loss_chg ;
258
260
});
259
261
} else if (importance_type == " cover" || importance_type == " total_cover" ) {
260
- add_score ([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
261
- gain_map[split] += p_tree->Stat (nidx).sum_hess ;
262
+ if (!model_.trees .empty () && model_.trees .front ()->IsMultiTarget ()) {
263
+ LOG (FATAL) << " cover/total_cover " << MTNotImplemented ();
264
+ }
265
+ add_score ([&](auto const & tree, bst_node_t nidx, bst_feature_t split) {
266
+ gain_map[split] += tree.Stat (nidx).sum_hess ;
262
267
});
263
268
} else {
264
269
LOG (FATAL)
0 commit comments