Skip to content

Commit 331df2e

Browse files
xadupresdpython
andauthored
Fix discrepencies when xgboost trees are empty. (#447)
* Fix discrepencies when xgboost trees are empty. Signed-off-by: xavier dupré <[email protected]> * fix wrong converting function Signed-off-by: xavier dupré <[email protected]> * eol Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent ccddab5 commit 331df2e

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def convert(scope, operator, container):
242242
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
243243
attr_pairs['post_transform'] = "LOGISTIC"
244244
attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']]
245+
if js_trees[0].get('leaf', None) == 0:
246+
attr_pairs['base_values'] = [0.5]
245247
else:
246248
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
247249
attr_pairs['post_transform'] = "SOFTMAX"

tests/xgboost/test_xgboost_converters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import unittest
88
import numpy as np
9+
from numpy.testing import assert_almost_equal
910
import pandas
1011
from sklearn.datasets import (
1112
load_diabetes, load_iris, make_classification, load_digits)
@@ -325,6 +326,22 @@ def test_xgboost_example_mnist(self):
325326
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
326327
basename="XGBoostExample")
327328

329+
def test_xgb_empty_tree(self):
330+
xgb = XGBClassifier(n_estimators=2, max_depth=2)
331+
332+
# simple dataset
333+
X = [[0, 1], [1, 1], [2, 0]]
334+
X = np.array(X, dtype=np.float32)
335+
y = [0, 1, 0]
336+
xgb.fit(X, y)
337+
conv_model = convert_xgboost(
338+
xgb, initial_types=[
339+
('input', FloatTensorType(shape=[None, X.shape[1]]))])
340+
sess = InferenceSession(conv_model.SerializeToString())
341+
res = sess.run(None, {'input': X.astype(np.float32)})
342+
assert_almost_equal(xgb.predict_proba(X), res[1])
343+
assert_almost_equal(xgb.predict(X), res[0])
344+
328345

329346
if __name__ == "__main__":
330347
unittest.main()

0 commit comments

Comments
 (0)