Skip to content

Commit 9b88495

Browse files
authored
[multi] Implement weight feature importance. (dmlc#10700)
1 parent 402e783 commit 9b88495

File tree

2 files changed: +45 −10 lines

src/gbm/gbtree.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,11 @@ class GBTree : public GradientBooster {
     auto add_score = [&](auto fn) {
       for (auto idx : trees) {
         CHECK_LE(idx, total_n_trees) << "Invalid tree index.";
-        auto const& p_tree = model_.trees[idx];
-        p_tree->WalkTree([&](bst_node_t nidx) {
-          auto const& node = (*p_tree)[nidx];
-          if (!node.IsLeaf()) {
-            split_counts[node.SplitIndex()]++;
-            fn(p_tree, nidx, node.SplitIndex());
+        auto const& tree = *model_.trees[idx];
+        tree.WalkTree([&](bst_node_t nidx) {
+          if (!tree.IsLeaf(nidx)) {
+            split_counts[tree.SplitIndex(nidx)]++;
+            fn(tree, nidx, tree.SplitIndex(nidx));
           }
           return true;
         });
@@ -253,12 +252,18 @@ class GBTree : public GradientBooster {
       gain_map[split] = split_counts[split];
     });
   } else if (importance_type == "gain" || importance_type == "total_gain") {
-    add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
-      gain_map[split] += p_tree->Stat(nidx).loss_chg;
+    if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
+      LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
+    }
+    add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
+      gain_map[split] += tree.Stat(nidx).loss_chg;
     });
   } else if (importance_type == "cover" || importance_type == "total_cover") {
-    add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
-      gain_map[split] += p_tree->Stat(nidx).sum_hess;
+    if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
+      LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
+    }
+    add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
+      gain_map[split] += tree.Stat(nidx).sum_hess;
     });
   } else {
     LOG(FATAL)

tests/python/test_with_sklearn.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,36 @@ def test_feature_importances_weight():
         cls.feature_importances_


+def test_feature_importances_weight_vector_leaf() -> None:
+    from sklearn.datasets import make_multilabel_classification
+
+    X, y = make_multilabel_classification(random_state=1994)
+    with pytest.raises(ValueError, match="gain/total_gain"):
+        clf = xgb.XGBClassifier(multi_strategy="multi_output_tree")
+        clf.fit(X, y)
+        clf.feature_importances_
+
+    with pytest.raises(ValueError, match="cover/total_cover"):
+        clf = xgb.XGBClassifier(
+            multi_strategy="multi_output_tree", importance_type="cover"
+        )
+        clf.fit(X, y)
+        clf.feature_importances_
+
+    clf = xgb.XGBClassifier(
+        multi_strategy="multi_output_tree",
+        importance_type="weight",
+        colsample_bynode=0.2,
+    )
+    clf.fit(X, y, feature_weights=np.arange(0, X.shape[1]))
+    fi = clf.feature_importances_
+    assert fi[0] == 0.0
+    assert fi[-1] > fi[1] * 5
+
+    w = np.polynomial.Polynomial.fit(np.arange(0, X.shape[1]), fi, deg=1)
+    assert w.coef[1] > 0.03
+
+
 @pytest.mark.skipif(**tm.no_pandas())
 def test_feature_importances_gain():
     from sklearn.datasets import load_digits

0 commit comments

Comments
 (0)