
Commit 067b704

[backport] Fix inference with categorical feature. (dmlc#8591) (dmlc#8602) (dmlc#8638)
* Fix inference with categorical feature. (dmlc#8591)
* Fix windows build on buildkite. (dmlc#8602)
* workaround.
1 parent 1a834b2 commit 067b704

File tree

7 files changed: +79 -31 lines changed


doc/tutorials/categorical.rst

Lines changed: 5 additions & 5 deletions
@@ -138,11 +138,11 @@ Miscellaneous
 
 By default, XGBoost assumes input categories are integers starting from 0 till the number
 of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
-values due to mistakes or missing values. It can be negative value, integer values that
-can not be accurately represented by 32-bit floating point, or values that are larger than
-actual number of unique categories. During training this is validated but for prediction
-it's treated as the same as missing value for performance reasons. Lastly, missing values
-are treated as the same as numerical features (using the learned split direction).
+values due to mistakes or missing values in training dataset. It can be negative value,
+integer values that can not be accurately represented by 32-bit floating point, or values
+that are larger than actual number of unique categories. During training this is
+validated but for prediction it's treated as the same as not-chosen category for
+performance reasons.
 
 
 **********
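
A side note on the 32-bit float constraint mentioned in the updated text: integer category codes above 2^24 cannot be represented exactly, so they cannot round-trip through the float-based data matrix. A standalone sketch, independent of XGBoost:

#include <cstdint>
#include <iostream>

int main() {
  // 2^24 is the bound below which every non-negative integer is exactly
  // representable as a 32-bit float.
  float max_exact = 16777216.0f;
  float next = max_exact + 1.0f;  // 2^24 + 1 rounds back down to 2^24
  std::cout << (next == max_exact) << '\n';               // prints 1
  std::cout << static_cast<std::uint32_t>(next) << '\n';  // 16777216, not 16777217
}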

src/common/categorical.h

Lines changed: 9 additions & 8 deletions
@@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
   return cat < 0 || cat >= kMaxCat;
 }
 
-/* \brief Whether should it traverse to left branch of a tree.
+/**
+ * \brief Whether should it traverse to left branch of a tree.
  *
- * For one hot split, go to left if it's NOT the matching category.
+ * Go to left if it's NOT the matching category, which matches one-hot encoding.
  */
-template <bool validate = true>
-inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
+inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
   KCatBitField const s_cats(cats);
-  // FIXME: Size() is not accurate since it represents the size of bit set instead of
-  // actual number of categories.
-  if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
-    return dft_left;
+  if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
+    return true;
   }
 
   auto pos = KCatBitField::ToBitPos(cat);
+  // If the input category is larger than the size of the bit field, it implies that the
+  // category is not chosen. Otherwise the bit field would have the category instead of
+  // being smaller than the category value.
   if (pos.int_pos >= cats.size()) {
     return true;
   }
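
To make the fixed rule concrete, here is a minimal self-contained sketch of the same decision logic, with a plain std::vector<uint32_t> standing in for KCatBitField; InvalidCatSketch, DecisionSketch, and the 2^24 bound are illustrative assumptions, not the library's API:

// Sketch of the post-fix Decision(): go left unless `cat` is a stored
// (right-chosen) category. Invalid or out-of-range inputs behave like a
// not-chosen category instead of following the default direction.
#include <cstdint>
#include <iostream>
#include <vector>

bool InvalidCatSketch(float cat) {
  return cat < 0 || cat >= 16777216.0f;  // assumption: categories must fit a float exactly
}

bool DecisionSketch(std::vector<uint32_t> const& bits, float cat) {
  if (InvalidCatSketch(cat)) {
    return true;  // invalid input: treated as not-chosen, go left
  }
  auto c = static_cast<uint32_t>(cat);  // truncation rounds toward zero
  if (c / 32 >= bits.size()) {
    return true;  // beyond the stored bit field: the category was never chosen
  }
  return !(bits[c / 32] & (1u << (c % 32)));  // set bit => matching => go right
}

int main() {
  std::vector<uint32_t> bits(1, 0);
  bits[0] |= 1u << 13;  // mark category 13 as going right
  std::cout << DecisionSketch(bits, 13.0f)    // 0: matching category, go right
            << DecisionSketch(bits, 256.0f)   // 1: beyond the bit field, go left
            << DecisionSketch(bits, -1.0f)    // 1: invalid, go left
            << '\n';
}

The 256.0f case is exactly the scenario the new MinimalSet test below exercises: a prediction-time category larger than anything seen during training.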

src/common/partition_builder.h

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ class PartitionBuilder {
       auto gidx = gidx_calc(ridx);
       bool go_left = default_left;
       if (gidx > -1) {
-        go_left = Decision(node_cats, cut_values[gidx], default_left);
+        go_left = Decision(node_cats, cut_values[gidx]);
       }
       return go_left;
     } else {
@@ -157,7 +157,7 @@ class PartitionBuilder {
       bool go_left = default_left;
       if (gidx > -1) {
         if (is_cat) {
-          go_left = Decision(node_cats, cut_values[gidx], default_left);
+          go_left = Decision(node_cats, cut_values[gidx]);
         } else {
           go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
         }

src/predictor/predict_fn.h

Lines changed: 1 addition & 3 deletions
@@ -18,9 +18,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
   if (has_categorical && common::IsCat(cats.split_type, nid)) {
     auto node_categories =
         cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
-    return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
-               ? node.LeftChild()
-               : node.RightChild();
+    return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
   } else {
     return node.LeftChild() + !(fvalue < node.SplitCond());
   }
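
As an aside, the numerical path kept above, node.LeftChild() + !(fvalue < node.SplitCond()), is a branchless child selection; it assumes the right child index is the left child's plus one. A tiny sketch with a made-up name, NextNodeSketch:

#include <iostream>

// fvalue < split_cond yields +0 (stay on the left child); otherwise the
// comparison is false, including on ties, and +1 lands on the right child.
int NextNodeSketch(int left_child, float fvalue, float split_cond) {
  return left_child + !(fvalue < split_cond);
}

int main() {
  std::cout << NextNodeSketch(1, 0.5f, 1.0f) << '\n';  // 1: goes left
  std::cout << NextNodeSketch(1, 2.0f, 1.0f) << '\n';  // 2: goes right
}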

src/tree/updater_gpu_hist.cu

Lines changed: 2 additions & 3 deletions
@@ -403,8 +403,7 @@ struct GPUHistMakerDevice {
       go_left = data.split_node.DefaultLeft();
     } else {
       if (data.split_type == FeatureType::kCategorical) {
-        go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
-                                          data.split_node.DefaultLeft());
+        go_left = common::Decision(data.node_cats.Bits(), cut_value);
       } else {
         go_left = cut_value <= data.split_node.SplitCond();
       }
@@ -481,7 +480,7 @@ struct GPUHistMakerDevice {
         if (common::IsCat(d_feature_types, position)) {
           auto node_cats = categories.subspan(categories_segments[position].beg,
                                               categories_segments[position].size);
-          go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
+          go_left = common::Decision(node_cats, element);
         } else {
           go_left = element <= node.SplitCond();
         }

tests/cpp/common/test_categorical.cc

Lines changed: 57 additions & 7 deletions
@@ -1,11 +1,14 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/json.h>
+#include <xgboost/learner.h>
 
 #include <limits>
 
 #include "../../../src/common/categorical.h"
+#include "../helpers.h"
 
 namespace xgboost {
 namespace common {
@@ -15,29 +18,76 @@ TEST(Categorical, Decision) {
 
   ASSERT_TRUE(common::InvalidCat(a));
   std::vector<uint32_t> cats(256, 0);
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // larger than size
   a = 256;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // negative
   a = -1;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   CatBitField bits{cats};
   bits.Set(0);
   a = -0.5;
-  ASSERT_TRUE(Decision(cats, a, true));
+  ASSERT_TRUE(Decision(cats, a));
 
   // round toward 0
   a = 0.5;
-  ASSERT_FALSE(Decision(cats, a, true));
+  ASSERT_FALSE(Decision(cats, a));
 
   // valid
   a = 13;
   bits.Set(a);
-  ASSERT_FALSE(Decision(bits.Bits(), a, true));
+  ASSERT_FALSE(Decision(bits.Bits(), a));
+}
+
+/**
+ * Test for running inference with input category greater than the one stored in tree.
+ */
+TEST(Categorical, MinimalSet) {
+  std::size_t constexpr kRows = 256, kCols = 1, kCat = 3;
+  std::vector<FeatureType> types{FeatureType::kCategorical};
+  auto Xy =
+      RandomDataGenerator{kRows, kCols, 0.0}.Type(types).MaxCategory(kCat).GenerateDMatrix(true);
+
+  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+  learner->SetParam("max_depth", "1");
+  learner->SetParam("tree_method", "hist");
+  learner->Configure();
+  learner->UpdateOneIter(0, Xy);
+
+  Json model{Object{}};
+  learner->SaveModel(&model);
+  auto tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
+  ASSERT_GE(get<I32Array const>(tree["categories"]).size(), 1);
+  auto v = get<I32Array const>(tree["categories"])[0];
+
+  HostDeviceVector<float> predt;
+  {
+    std::vector<float> data{static_cast<float>(kCat),
+                            static_cast<float>(kCat + 1), 32.0f, 33.0f, 34.0f};
+    auto test = GetDMatrixFromData(data, data.size(), kCols);
+    learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
+    ASSERT_EQ(predt.Size(), data.size());
+    auto const& h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_EQ(v, 1);  // left child of root node
+    }
+  }
+
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+    learner->LoadModel(model);
+    std::vector<float> data = {static_cast<float>(v)};
+    auto test = GetDMatrixFromData(data, data.size(), kCols);
+    learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
+    auto const& h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_EQ(v, 2);  // right child of root node
+    }
+  }
 }
 } // namespace common
 } // namespace xgboost

tests/python/test_with_sklearn.py

Lines changed: 3 additions & 3 deletions
@@ -1029,9 +1029,9 @@ def test_pandas_input():
 
     clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")
    clf_isotonic.fit(train, target)
-    assert isinstance(
-        clf_isotonic.calibrated_classifiers_[0].estimator, xgb.XGBClassifier
-    )
+    clf = clf_isotonic.calibrated_classifiers_[0]
+    est = clf.estimator if hasattr(clf, "estimator") else clf.base_estimator
+    assert isinstance(est, xgb.XGBClassifier)
     np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))