Skip to content

Commit d7db0ff

Browse files
authored
Fix conversion of Booster for xgboost>=1.6.1 (#567)
* fix base_score for binary classification Signed-off-by: xadupre <[email protected]> * Update for xgboost 1.6.1 Signed-off-by: xadupre <[email protected]> * update ci Signed-off-by: xadupre <[email protected]> * fix for catboost Signed-off-by: xadupre <[email protected]> * fix for svm Signed-off-by: xadupre <[email protected]> * lint Signed-off-by: xadupre <[email protected]>
1 parent a699341 commit d7db0ff

File tree

6 files changed

+95
-44
lines changed

6 files changed

+95
-44
lines changed

.azure-pipelines/linux-conda-CI.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@ jobs:
1515
strategy:
1616
matrix:
1717

18-
Python39-1120-RT1110:
18+
Python39-1120-RT1110-xgb161:
1919
python.version: '3.9'
2020
ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4'
2121
ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
2222
COREML_PATH: git+https://github.com/apple/[email protected]
23+
xgboost.version: '>=1.6.1'
24+
25+
Python39-1120-RT1110-xgb142:
26+
python.version: '3.9'
27+
ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4'
28+
ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
29+
COREML_PATH: git+https://github.com/apple/[email protected]
30+
xgboost.version: '==1.4.2'
2331

2432
Python39-1110-RT1110:
2533
python.version: '3.9'
@@ -126,7 +134,10 @@ jobs:
126134
export PYTHONPATH=.
127135
python -c "import onnxruntime;print('onnx:',onnx.__version__)"
128136
python -c "import onnxconverter_common;print('cc:',onnxconverter_common.__version__)"
137+
python -c "import onnx;print('onnx:',onnx.__version__)"
129138
python -c "import onnxruntime;print('ort:',onnxruntime.__version__)"
139+
python -c "import xgboost;print('xgboost:',xgboost.__version__)"
140+
python -c "import lightgbm;print('lightgbm:',lightgbm.__version__)"
130141
displayName: 'version'
131142
132143
- script: |

.azure-pipelines/win32-conda-CI.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ jobs:
122122
python -m pip install -e .
123123
export PYTHONPATH=.
124124
python -c "import onnxconverter_common;print(onnxconverter_common.__version__)"
125+
python -c "import onnx;print(onnx.__version__)"
125126
python -c "import onnxruntime;print(onnxruntime.__version__)"
127+
python -c "import xgboost;print(xgboost.__version__)"
128+
python -c "import lightgbm;print(lightgbm.__version__)"
126129
displayName: 'version'
127130
128131
- script: |

onnxmltools/convert/xgboost/_parse.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import json
44
import re
5+
import pprint
6+
from packaging.version import Version
57
import numpy as np
6-
from xgboost import XGBRegressor, XGBClassifier
8+
from xgboost import XGBRegressor, XGBClassifier, __version__
79
from onnxconverter_common.data_types import FloatTensorType
810
from ..common._container import XGBoostModelContainer
911
from ..common._topology import Topology
@@ -27,23 +29,36 @@ def _append_covers(node):
2729

2830

2931
def _get_attributes(booster):
30-
# num_class
31-
state = booster.__getstate__()
32-
bstate = bytes(state['handle'])
33-
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
34-
objs = list(set(reg.findall(bstate)))
35-
assert len(objs) == 1, 'Missing required property "tree_info".'
36-
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
37-
num_class = len(set(tree_info))
38-
3932
atts = booster.attributes()
4033
dp = booster.get_dump(dump_format='json', with_stats=True)
4134
res = [json.loads(d) for d in dp]
42-
trees = len(res)
43-
try:
35+
36+
# num_class
37+
if Version(__version__) < Version('1.5'):
38+
state = booster.__getstate__()
39+
bstate = bytes(state['handle'])
40+
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
41+
objs = list(set(reg.findall(bstate)))
42+
if len(objs) != 1:
43+
raise RuntimeError(
44+
"Unable to retrieve the tree coefficients from\n%s"
45+
"" % bstate.decode("ascii", errors="ignore"))
46+
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
47+
num_class = len(set(tree_info))
48+
trees = len(res)
49+
try:
50+
ntrees = booster.best_ntree_limit
51+
except AttributeError:
52+
ntrees = trees // num_class if num_class > 0 else trees
53+
else:
54+
trees = len(res)
4455
ntrees = booster.best_ntree_limit
45-
except AttributeError:
46-
ntrees = trees // num_class if num_class > 0 else trees
56+
num_class = trees // ntrees
57+
if num_class == 0:
58+
raise RuntimeError(
59+
"Unable to retrieve the number of classes, trees=%d, ntrees=%d." % (
60+
trees, ntrees))
61+
4762
kwargs = atts.copy()
4863
kwargs['feature_names'] = booster.feature_names
4964
kwargs['n_estimators'] = ntrees
@@ -62,14 +77,23 @@ def _get_attributes(booster):
6277
# classification
6378
kwargs['num_class'] = num_class
6479
if num_class != 1:
65-
reg = re.compile(b'(multi:[a-z]{1,15})')
66-
objs = list(set(reg.findall(bstate)))
67-
if len(objs) == 1:
68-
kwargs["objective"] = objs[0].decode('ascii')
80+
if Version(__version__) < Version('1.5'):
81+
reg = re.compile(b'(multi:[a-z]{1,15})')
82+
objs = list(set(reg.findall(bstate)))
83+
if len(objs) == 1:
84+
kwargs["objective"] = objs[0].decode('ascii')
85+
else:
86+
raise RuntimeError(
87+
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
88+
"." % (objs, trees, ntrees, kwargs['num_class']))
6989
else:
70-
raise RuntimeError(
71-
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
72-
"." % (objs, trees, ntrees, kwargs['num_class']))
90+
att = json.loads(booster.save_config())
91+
kwargs["objective"] = att['learner']['objective']['name']
92+
nc = int(att['learner']['learner_model_param']['num_class'])
93+
if nc != num_class:
94+
raise RuntimeError(
95+
"Mismatched value %r != %r from\n%s" % (
96+
nc, num_class, pprint.pformat(att)))
7397
else:
7498
kwargs["objective"] = "binary:logistic"
7599

onnxmltools/utils/tests_helper.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,18 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
113113
prediction = [model.predict(datax)]
114114
elif hasattr(model, "predict_proba"):
115115
# Classifier
116-
prediction = [model.predict(data), model.predict_proba(data)]
116+
if hasattr(model, 'get_params'):
117+
params = model.get_params()
118+
if 'objective' in params:
119+
objective = params['objective']
120+
if objective == "multi:softmax":
121+
prediction = [model.predict(data)]
122+
else:
123+
prediction = [model.predict(data), model.predict_proba(data)]
124+
else:
125+
prediction = [model.predict(data), model.predict_proba(data)]
126+
else:
127+
prediction = [model.predict(data), model.predict_proba(data)]
117128
elif hasattr(model, "predict_with_probabilities"):
118129
# Classifier that returns all in one go
119130
prediction = model.predict_with_probabilities(data)

onnxmltools/utils/utils_backend.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Helpers to test runtimes.
55
"""
66
import os
7-
import sys
87
import glob
98
import pickle
109
import packaging.version as pv
@@ -75,13 +74,9 @@ def compare_backend(backend, test, decimal=5, options=None, verbose=False, conte
7574
if the comparison failed.
7675
"""
7776
if backend == "onnxruntime":
78-
if sys.version_info[0] == 2:
79-
# onnxruntime is not available on Python 2.
80-
return
8177
from .utils_backend_onnxruntime import compare_runtime
8278
return compare_runtime(test, decimal, options, verbose)
83-
else:
84-
raise ValueError("Does not support backend '{0}'.".format(backend))
79+
raise ValueError("Does not support backend '{0}'.".format(backend))
8580

8681

8782
def search_converted_models(root=None):

tests/xgboost/test_xgboost_converters.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,10 @@ def test_xgb_classifier_reglog(self):
109109
x_test, xgb, conv_model,
110110
basename="SklearnXGBClassifierRegLog")
111111

112-
def test_xgb_classifier_multi_str_labels(self):
113-
xgb, x_test = _fit_classification_model(
114-
XGBClassifier(n_estimators=4), 5, is_str=True)
115-
conv_model = convert_xgboost(
116-
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))],
117-
target_opset=TARGET_OPSET)
118-
self.assertTrue(conv_model is not None)
119-
dump_data_and_model(
120-
x_test, xgb, conv_model,
121-
basename="SklearnXGBClassifierMultiStrLabels")
122-
123112
def test_xgb_classifier_multi_discrete_int_labels(self):
124113
iris = load_iris()
125114
x = iris.data[:, :2]
126115
y = iris.target
127-
y[y == 0] = 10
128-
y[y == 1] = 20
129-
y[y == 2] = -30
130116
x_train, x_test, y_train, _ = train_test_split(x,
131117
y,
132118
test_size=0.5,
@@ -241,7 +227,7 @@ def test_xgboost_10(self):
241227
X_test.astype(np.float32), regressor, model_onnx,
242228
basename="XGBBoosterRegBug")
243229

244-
def test_xgboost_classifier_i5450(self):
230+
def test_xgboost_classifier_i5450_softmax(self):
245231
iris = load_iris()
246232
X, y = iris.data, iris.target
247233
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
@@ -255,6 +241,26 @@ def test_xgboost_classifier_i5450(self):
255241
predict_list = [1., 20., 466., 0.]
256242
predict_array = np.array(predict_list).reshape((1,-1)).astype(np.float32)
257243
pred_onx = sess.run([label_name], {input_name: predict_array})[0]
244+
bst = clr.get_booster()
245+
bst.dump_model('dump.raw.txt')
246+
dump_data_and_model(
247+
X_test.astype(np.float32) + 1e-5, clr, onx,
248+
basename="XGBClassifierIris-Out0")
249+
250+
def test_xgboost_classifier_i5450(self):
251+
iris = load_iris()
252+
X, y = iris.data, iris.target
253+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
254+
clr = XGBClassifier(objective="multi:softprob", max_depth=1, n_estimators=2)
255+
clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
256+
initial_type = [('float_input', FloatTensorType([None, 4]))]
257+
onx = convert_xgboost(clr, initial_types=initial_type, target_opset=TARGET_OPSET)
258+
sess = InferenceSession(onx.SerializeToString())
259+
input_name = sess.get_inputs()[0].name
260+
label_name = sess.get_outputs()[1].name
261+
predict_list = [1., 20., 466., 0.]
262+
predict_array = np.array(predict_list).reshape((1,-1)).astype(np.float32)
263+
pred_onx = sess.run([label_name], {input_name: predict_array})[0]
258264
pred_xgboost = sessresults=clr.predict_proba(predict_array)
259265
bst = clr.get_booster()
260266
bst.dump_model('dump.raw.txt')
@@ -364,4 +370,5 @@ def test_onnxrt_python_xgbclassifier(self):
364370

365371

366372
if __name__ == "__main__":
373+
# TestXGBoostModels().test_xgboost_booster_classifier_multiclass_softprob()
367374
unittest.main()

0 commit comments

Comments
 (0)