Skip to content

Commit b04d05b

Browse files
xadupre and wenbingl authored
Fix xgboost converter, check value of best_ntree_limit (#429)
Co-authored-by: Wenbing Li <[email protected]>
1 parent bf8e1c4 commit b04d05b

File tree

2 files changed: 45 additions (+45) and 33 deletions (-33)

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,12 @@ def convert(scope, operator, container):
188188

189189
attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
190190
attr_pairs['base_values'] = [base_score]
191+
192+
bst = xgb_node.get_booster()
193+
best_ntree_limit = getattr(bst, 'best_ntree_limit', len(js_trees))
194+
if best_ntree_limit < len(js_trees):
195+
js_trees = js_trees[:best_ntree_limit]
196+
191197
XGBConverter.fill_tree_attributes(js_trees, attr_pairs, [1 for _ in js_trees], False)
192198

193199
# add nodes
@@ -222,13 +228,19 @@ def convert(scope, operator, container):
222228
objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
223229

224230
params = XGBConverter.get_xgb_params(xgb_node)
225-
226231
attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
227232
XGBConverter.fill_tree_attributes(js_trees, attr_pairs, [1 for _ in js_trees], True)
233+
ncl = (max(attr_pairs['class_treeids']) + 1) // params['n_estimators']
234+
235+
bst = xgb_node.get_booster()
236+
best_ntree_limit = getattr(bst, 'best_ntree_limit', len(js_trees)) * ncl
237+
if best_ntree_limit < len(js_trees):
238+
js_trees = js_trees[:best_ntree_limit]
239+
attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
240+
XGBConverter.fill_tree_attributes(js_trees, attr_pairs, [1 for _ in js_trees], True)
228241

229242
if len(attr_pairs['class_treeids']) == 0:
230243
raise RuntimeError("XGBoost model is empty.")
231-
ncl = (max(attr_pairs['class_treeids']) + 1) // params['n_estimators']
232244
if ncl <= 1:
233245
ncl = 2
234246
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.

tests/xgboost/test_xgboost_converters.py

Lines changed: 31 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from onnxmltools.convert import convert_xgboost
1515
from onnxmltools.convert.common.data_types import FloatTensorType
1616
from onnxmltools.utils import dump_data_and_model
17+
from onnxruntime import InferenceSession
1718

1819

1920
def _fit_classification_model(model, n_classes, is_str=False, dtype=None):
@@ -31,8 +32,6 @@ def _fit_classification_model(model, n_classes, is_str=False, dtype=None):
3132

3233
class TestXGBoostModels(unittest.TestCase):
3334

34-
@unittest.skipIf(sys.version_info[0] == 2,
35-
reason="xgboost converter not tested on python 2")
3635
def test_xgb_regressor(self):
3736
iris = load_diabetes()
3837
x = iris.data
@@ -42,7 +41,7 @@ def test_xgb_regressor(self):
4241
xgb = XGBRegressor()
4342
xgb.fit(x_train, y_train)
4443
conv_model = convert_xgboost(
45-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
44+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
4645
self.assertTrue(conv_model is not None)
4746
dump_data_and_model(
4847
x_test.astype("float32"),
@@ -54,12 +53,10 @@ def test_xgb_regressor(self):
5453
"< StrictVersion('1.3.0')",
5554
)
5655

57-
@unittest.skipIf(sys.version_info[0] == 2,
58-
reason="xgboost converter not tested on python 2")
5956
def test_xgb_classifier(self):
6057
xgb, x_test = _fit_classification_model(XGBClassifier(), 2)
6158
conv_model = convert_xgboost(
62-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
59+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
6360
self.assertTrue(conv_model is not None)
6461
dump_data_and_model(
6562
x_test,
@@ -71,8 +68,6 @@ def test_xgb_classifier(self):
7168
"< StrictVersion('1.3.0')",
7269
)
7370

74-
@unittest.skipIf(sys.version_info[0] == 2,
75-
reason="xgboost converter not tested on python 2")
7671
def test_xgb_classifier_uint8(self):
7772
xgb, x_test = _fit_classification_model(
7873
XGBClassifier(), 2, dtype=np.uint8)
@@ -89,12 +84,10 @@ def test_xgb_classifier_uint8(self):
8984
"< StrictVersion('1.3.0')",
9085
)
9186

92-
@unittest.skipIf(sys.version_info[0] == 2,
93-
reason="xgboost converter not tested on python 2")
9487
def test_xgb_classifier_multi(self):
9588
xgb, x_test = _fit_classification_model(XGBClassifier(), 3)
9689
conv_model = convert_xgboost(
97-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
90+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
9891
self.assertTrue(conv_model is not None)
9992
dump_data_and_model(
10093
x_test,
@@ -106,13 +99,11 @@ def test_xgb_classifier_multi(self):
10699
"< StrictVersion('1.3.0')",
107100
)
108101

109-
@unittest.skipIf(sys.version_info[0] == 2,
110-
reason="xgboost converter not tested on python 2")
111102
def test_xgb_classifier_multi_reglog(self):
112103
xgb, x_test = _fit_classification_model(
113104
XGBClassifier(objective='reg:logistic'), 4)
114105
conv_model = convert_xgboost(
115-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
106+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
116107
self.assertTrue(conv_model is not None)
117108
dump_data_and_model(
118109
x_test,
@@ -124,13 +115,11 @@ def test_xgb_classifier_multi_reglog(self):
124115
"< StrictVersion('1.3.0')",
125116
)
126117

127-
@unittest.skipIf(sys.version_info[0] == 2,
128-
reason="xgboost converter not tested on python 2")
129118
def test_xgb_classifier_reglog(self):
130119
xgb, x_test = _fit_classification_model(
131120
XGBClassifier(objective='reg:logistic'), 2)
132121
conv_model = convert_xgboost(
133-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
122+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
134123
self.assertTrue(conv_model is not None)
135124
dump_data_and_model(
136125
x_test,
@@ -142,13 +131,11 @@ def test_xgb_classifier_reglog(self):
142131
"< StrictVersion('1.3.0')",
143132
)
144133

145-
@unittest.skipIf(sys.version_info[0] == 2,
146-
reason="xgboost converter not tested on python 2")
147134
def test_xgb_classifier_multi_str_labels(self):
148135
xgb, x_test = _fit_classification_model(
149136
XGBClassifier(n_estimators=4), 5, is_str=True)
150137
conv_model = convert_xgboost(
151-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
138+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
152139
self.assertTrue(conv_model is not None)
153140
dump_data_and_model(
154141
x_test,
@@ -160,8 +147,6 @@ def test_xgb_classifier_multi_str_labels(self):
160147
"< StrictVersion('1.3.0')",
161148
)
162149

163-
@unittest.skipIf(sys.version_info[0] == 2,
164-
reason="xgboost converter not tested on python 2")
165150
def test_xgb_classifier_multi_discrete_int_labels(self):
166151
iris = load_iris()
167152
x = iris.data[:, :2]
@@ -176,7 +161,7 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
176161
xgb = XGBClassifier(n_estimators=3)
177162
xgb.fit(x_train, y_train)
178163
conv_model = convert_xgboost(
179-
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
164+
xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))])
180165
self.assertTrue(conv_model is not None)
181166
dump_data_and_model(
182167
x_test.astype("float32"),
@@ -188,8 +173,6 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
188173
"< StrictVersion('1.3.0')",
189174
)
190175

191-
@unittest.skipIf(sys.version_info[0] == 2,
192-
reason="xgboost converter not tested on python 2")
193176
def test_xgboost_booster_classifier_bin(self):
194177
x, y = make_classification(n_classes=2, n_features=5,
195178
n_samples=100,
@@ -207,8 +190,6 @@ def test_xgboost_booster_classifier_bin(self):
207190
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
208191
basename="XGBBoosterMCl")
209192

210-
@unittest.skipIf(sys.version_info[0] == 2,
211-
reason="xgboost converter not tested on python 2")
212193
def test_xgboost_booster_classifier_multiclass(self):
213194
x, y = make_classification(n_classes=3, n_features=5,
214195
n_samples=100,
@@ -227,8 +208,6 @@ def test_xgboost_booster_classifier_multiclass(self):
227208
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
228209
basename="XGBBoosterMCl")
229210

230-
@unittest.skipIf(sys.version_info[0] == 2,
231-
reason="xgboost converter not tested on python 2")
232211
def test_xgboost_booster_classifier_reg(self):
233212
x, y = make_classification(n_classes=2, n_features=5,
234213
n_samples=100,
@@ -247,8 +226,6 @@ def test_xgboost_booster_classifier_reg(self):
247226
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
248227
basename="XGBBoosterReg")
249228

250-
@unittest.skipIf(sys.version_info[0] == 2,
251-
reason="xgboost converter not tested on python 2")
252229
def test_xgboost_10(self):
253230
this = os.path.abspath(os.path.dirname(__file__))
254231
train = os.path.join(this, "input_fail_train.csv")
@@ -282,6 +259,29 @@ def test_xgboost_10(self):
282259
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
283260
basename="XGBBoosterRegBug")
284261

262+
def test_xgboost_classifier_i5450(self):
263+
iris = load_iris()
264+
X, y = iris.data, iris.target
265+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
266+
clr = XGBClassifier(objective="multi:softmax", max_depth=1, n_estimators=2)
267+
clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
268+
initial_type = [('float_input', FloatTensorType([None, 4]))]
269+
onx = convert_xgboost(clr, initial_types=initial_type)
270+
sess = InferenceSession(onx.SerializeToString())
271+
input_name = sess.get_inputs()[0].name
272+
label_name = sess.get_outputs()[1].name
273+
predict_list = [1., 20., 466., 0.]
274+
predict_array = np.array(predict_list).reshape((1,-1)).astype(np.float32)
275+
pred_onx = sess.run([label_name], {input_name: predict_array})[0]
276+
pred_xgboost = sessresults=clr.predict_proba(predict_array)
277+
bst = clr.get_booster()
278+
bst.dump_model('dump.raw.txt')
279+
dump_data_and_model(
280+
X_test.astype(np.float32) + 1e-5,
281+
clr, onx,
282+
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
283+
basename="XGBClassifierIris")
284+
285285
def test_xgboost_example_mnist(self):
286286
"""
287287
Train a simple xgboost model and store associated artefacts.

0 commit comments

Comments
 (0)