Skip to content

Commit 8ab3401

Browse files
authored
Investigate a bug in xgboost (#520)
Infer num_class and n_estimators through tree_info; add target_opset to test cases; adjust opset version for hummingbird due to torch opset version support. Signed-off-by: BowenBao <[email protected]>
1 parent adc41ee commit 8ab3401

File tree

4 files changed

+73
-33
lines changed

4 files changed

+73
-33
lines changed

onnxmltools/convert/xgboost/_parse.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,23 @@ def _append_covers(node):
2727

2828

2929
def _get_attributes(booster):
30+
# num_class
31+
state = booster.__getstate__()
32+
bstate = bytes(state['handle'])
33+
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
34+
objs = list(set(reg.findall(bstate)))
35+
assert len(objs) == 1, 'Missing required property "tree_info".'
36+
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
37+
num_class = len(set(tree_info))
38+
3039
atts = booster.attributes()
31-
ntrees = booster.best_ntree_limit
3240
dp = booster.get_dump(dump_format='json', with_stats=True)
3341
res = [json.loads(d) for d in dp]
3442
trees = len(res)
43+
try:
44+
ntrees = booster.best_ntree_limit
45+
except AttributeError:
46+
ntrees = trees // num_class if num_class > 0 else trees
3547
kwargs = atts.copy()
3648
kwargs['feature_names'] = booster.feature_names
3749
kwargs['n_estimators'] = ntrees
@@ -43,34 +55,22 @@ def _get_attributes(booster):
4355

4456
if all(map(lambda x: int(x) == x, set(covs))):
4557
# regression
58+
kwargs['num_target'] = num_class
4659
kwargs['num_class'] = 0
47-
if trees > ntrees > 0:
48-
kwargs['num_target'] = trees // ntrees
49-
kwargs["objective"] = "reg:squarederror"
50-
else:
51-
kwargs['num_target'] = 1
52-
kwargs["objective"] = "reg:squarederror"
60+
kwargs["objective"] = "reg:squarederror"
5361
else:
5462
# classification
55-
kwargs['num_target'] = 0
56-
if trees > ntrees > 0:
57-
state = booster.__getstate__()
58-
bstate = bytes(state['handle'])
63+
kwargs['num_class'] = num_class
64+
if num_class != 1:
5965
reg = re.compile(b'(multi:[a-z]{1,15})')
6066
objs = list(set(reg.findall(bstate)))
61-
if len(objs) != 1:
62-
if '"name":"binary:logistic"' in str(bstate):
63-
kwargs['num_class'] = 1
64-
kwargs["objective"] = "binary:logistic"
65-
else:
66-
raise RuntimeError(
67-
"Unable to guess objective in %r (trees=%r, ntrees=%r)"
68-
"." % (objs, trees, ntrees))
69-
else:
70-
kwargs['num_class'] = trees // ntrees
67+
if len(objs) == 1:
7168
kwargs["objective"] = objs[0].decode('ascii')
69+
else:
70+
raise RuntimeError(
71+
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
72+
"." % (objs, trees, ntrees, kwargs['num_class']))
7273
else:
73-
kwargs['num_class'] = 1
7474
kwargs["objective"] = "binary:logistic"
7575

7676
if 'base_score' not in kwargs:

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def common_members(xgb_node, inputs):
3838
booster = xgb_node.get_booster()
3939
# The json format was available in October 2017.
4040
# XGBoost 0.7 was the first version released with it.
41-
js_tree_list = booster.get_dump(with_stats=True, dump_format = 'json')
41+
js_tree_list = booster.get_dump(with_stats=True, dump_format='json')
4242
js_trees = [json.loads(s) for s in js_tree_list]
4343
return objective, base_score, js_trees
4444

tests/hummingbirdml/test_LightGbmTreeEnsembleConverters_hummingbird.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121

2222
TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())
23+
# PyTorch 1.8.1 supports up to opset version 13.
24+
HUMMINGBIRD_TARGET_OPSET = min(TARGET_OPSET, 13)
2325

2426

2527
class TestLightGbmTreeEnsembleModelsHummingBird(unittest.TestCase):
@@ -46,7 +48,7 @@ def test_lightgbm_booster_classifier(self):
4648
data)
4749
model_onnx, prefix = convert_model(model, 'tree-based classifier',
4850
[('input', FloatTensorType([None, 2]))], without_onnx_ml=True,
49-
target_opset=TARGET_OPSET,
51+
target_opset=HUMMINGBIRD_TARGET_OPSET,
5052
zipmap=False)
5153
dump_data_and_model(X, model, model_onnx,
5254
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
@@ -63,16 +65,16 @@ def test_lightgbm_booster_classifier_zipmap(self):
6365
data)
6466
model_onnx, prefix = convert_model(model, 'tree-based classifier',
6567
[('input', FloatTensorType([None, 2]))], without_onnx_ml=False,
66-
target_opset=TARGET_OPSET)
68+
target_opset=HUMMINGBIRD_TARGET_OPSET)
6769
assert "zipmap" in str(model_onnx).lower()
6870
with self.assertRaises(NotImplementedError):
6971
convert_model(model, 'tree-based classifier',
7072
[('input', FloatTensorType([None, 2]))], without_onnx_ml=True,
71-
target_opset=TARGET_OPSET)
72-
73+
target_opset=HUMMINGBIRD_TARGET_OPSET)
74+
7375
model_onnx, prefix = convert_model(model, 'tree-based classifier',
7476
[('input', FloatTensorType([None, 2]))], without_onnx_ml=True,
75-
target_opset=TARGET_OPSET, zipmap=False)
77+
target_opset=HUMMINGBIRD_TARGET_OPSET, zipmap=False)
7678
dump_data_and_model(X, model, model_onnx,
7779
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
7880
basename=prefix + "BoosterBin" + model.__class__.__name__)
@@ -88,7 +90,7 @@ def test_lightgbm_booster_multi_classifier(self):
8890
data)
8991
model_onnx, prefix = convert_model(model, 'tree-based classifier',
9092
[('input', FloatTensorType([None, 2]))], without_onnx_ml=True,
91-
target_opset=TARGET_OPSET, zipmap=False)
93+
target_opset=HUMMINGBIRD_TARGET_OPSET, zipmap=False)
9294
dump_data_and_model(X, model, model_onnx,
9395
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
9496
basename=prefix + "BoosterBin" + model.__class__.__name__)
@@ -108,7 +110,7 @@ def test_lightgbm_booster_regressor(self):
108110
data)
109111
model_onnx, prefix = convert_model(model, 'tree-based binary regressor',
110112
[('input', FloatTensorType([None, 2]))], without_onnx_ml=True,
111-
target_opset=TARGET_OPSET, zipmap=False)
113+
target_opset=HUMMINGBIRD_TARGET_OPSET, zipmap=False)
112114
dump_data_and_model(X, model, model_onnx,
113115
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.0.0')",
114116
basename=prefix + "BoosterBin" + model.__class__.__name__)
@@ -203,7 +205,7 @@ def _test_lightgbm_booster_regressor(self):
203205
y = [0, 1, 1.1]
204206
data = lightgbm.Dataset(X, label=y)
205207
model = lightgbm.train(
206-
{"boosting_type": "gbdt", "objective": "regression", "n_estimators": 3,
208+
{"boosting_type": "gbdt", "objective": "regression", "n_estimators": 3,
207209
"min_child_samples": 1, "max_depth": 1, 'num_thread': 1},
208210
data,
209211
)

tests/xgboost/test_xgboost_converters.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.datasets import (
1212
load_diabetes, load_iris, make_classification, load_digits)
1313
from sklearn.model_selection import train_test_split
14-
from xgboost import XGBRegressor, XGBClassifier, train, DMatrix
14+
from xgboost import XGBRegressor, XGBClassifier, train, DMatrix, Booster, train as train_xgb
1515
from sklearn.preprocessing import StandardScaler
1616
from onnx.defs import onnx_opset_version
1717
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
@@ -181,7 +181,7 @@ def test_xgboost_booster_classifier_multiclass_softmax(self):
181181
random_state=42, n_informative=3)
182182
x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
183183
random_state=42)
184-
184+
185185
data = DMatrix(x_train, label=y_train)
186186
model = train({'objective': 'multi:softmax',
187187
'n_estimators': 3, 'min_child_samples': 1,
@@ -303,6 +303,44 @@ def test_xgb_empty_tree(self):
303303
assert_almost_equal(xgb.predict_proba(X), res[1])
304304
assert_almost_equal(xgb.predict(X), res[0])
305305

306+
def test_xgb_best_tree_limit(self):
307+
308+
# Train
309+
iris = load_iris()
310+
X, y = iris.data, iris.target
311+
X_train, X_test, y_train, y_test = train_test_split(X, y)
312+
dtrain = DMatrix(X_train, label=y_train)
313+
dtest = DMatrix(X_test)
314+
param = {'objective': 'multi:softmax', 'num_class': 3}
315+
bst_original = train_xgb(param, dtrain, 10)
316+
initial_type = [('float_input', FloatTensorType([None, 4]))]
317+
bst_original.save_model('model.json')
318+
319+
onx_loaded = convert_xgboost(
320+
bst_original, initial_types=initial_type,
321+
target_opset=TARGET_OPSET)
322+
sess = InferenceSession(onx_loaded.SerializeToString())
323+
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
324+
assert_almost_equal(bst_original.predict(dtest, output_margin=True), res[1], decimal=5)
325+
assert_almost_equal(bst_original.predict(dtest), res[0])
326+
327+
# After being restored, the loaded booster is not exactly the same
328+
# in memory. `best_ntree_limit` is not saved during `save_model`.
329+
bst_loaded = Booster()
330+
bst_loaded.load_model('model.json')
331+
bst_loaded.save_model('model2.json')
332+
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True),
333+
bst_original.predict(dtest, output_margin=True), decimal=5)
334+
assert_almost_equal(bst_loaded.predict(dtest), bst_original.predict(dtest))
335+
336+
onx_loaded = convert_xgboost(
337+
bst_loaded, initial_types=initial_type,
338+
target_opset=TARGET_OPSET)
339+
sess = InferenceSession(onx_loaded.SerializeToString())
340+
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
341+
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5)
342+
assert_almost_equal(bst_loaded.predict(dtest), res[0])
343+
306344

307345
if __name__ == "__main__":
308346
unittest.main()

0 commit comments

Comments
 (0)