Skip to content

Commit e273afb

Browse files
authored
Support objective='reg:logistic' for XGBoost converters (#270)
Support objective='reg:logistic' for XGBoost converters
1 parent 9faf603 commit e273afb

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ def convert(scope, operator, container):
219219
objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
220220

221221
params = XGBConverter.get_xgb_params(xgb_node)
222-
222+
223223
attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
224224
XGBConverter.fill_tree_attributes(js_trees, attr_pairs, [1 for _ in js_trees], True)
225-
225+
226226
if len(attr_pairs['class_treeids']) == 0:
227227
raise RuntimeError("XGBoost model is empty.")
228228
ncl = (max(attr_pairs['class_treeids']) + 1) // params['n_estimators']
@@ -251,6 +251,13 @@ def convert(scope, operator, container):
251251
container.add_node('TreeEnsembleClassifier', operator.input_full_names,
252252
operator.output_full_names,
253253
op_domain='ai.onnx.ml', **attr_pairs)
254+
elif objective == "reg:logistic":
255+
ncl = len(js_trees) // params['n_estimators']
256+
if ncl == 1:
257+
ncl = 2
258+
container.add_node('TreeEnsembleClassifier', operator.input_full_names,
259+
operator.output_full_names,
260+
op_domain='ai.onnx.ml', **attr_pairs)
254261
else:
255262
raise RuntimeError("Unexpected objective: {0}".format(objective))
256263

onnxmltools/convert/xgboost/shape_calculators/Classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def calculate_xgboost_classifier_output_shapes(operator):
2626
ncl = 2
2727
else:
2828
ncl = ntrees // params['n_estimators']
29+
if objective == "reg:logistic" and ncl == 1:
30+
ncl = 2
2931
operator.outputs[0].type = Int64TensorType(shape=[N])
3032
operator.outputs[1].type = operator.outputs[1].type = FloatTensorType([N, ncl])
3133

tests/xgboost/test_xgboost_converters.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class TestXGBoostModels(unittest.TestCase):
1414

15-
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converted not tested on python 2")
15+
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
1616
def test_xgb_regressor(self):
1717
iris = load_iris()
1818
X = iris.data[:, :2]
@@ -24,7 +24,7 @@ def test_xgb_regressor(self):
2424
self.assertTrue(conv_model is not None)
2525
dump_single_regression(xgb, suffix="-Dec4")
2626

27-
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converted not tested on python 2")
27+
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
2828
def test_xgb_classifier(self):
2929
iris = load_iris()
3030
X = iris.data[:, :2]
@@ -35,9 +35,9 @@ def test_xgb_classifier(self):
3535
xgb.fit(X, y)
3636
conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
3737
self.assertTrue(conv_model is not None)
38-
dump_binary_classification(xgb, verbose=True)
38+
dump_binary_classification(xgb)
3939

40-
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converted not tested on python 2")
40+
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
4141
def test_xgb_classifier_multi(self):
4242
iris = load_iris()
4343
X = iris.data[:, :2]
@@ -49,6 +49,32 @@ def test_xgb_classifier_multi(self):
4949
self.assertTrue(conv_model is not None)
5050
dump_multiple_classification(xgb, allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')")
5151

52+
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
53+
def test_xgb_classifier_multi_reglog(self):
54+
iris = load_iris()
55+
X = iris.data[:, :2]
56+
y = iris.target
57+
58+
xgb = XGBClassifier(objective='reg:logistic')
59+
xgb.fit(X, y)
60+
conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
61+
self.assertTrue(conv_model is not None)
62+
dump_multiple_classification(xgb, suffix="RegLog",
63+
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')")
64+
65+
@unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
66+
def test_xgb_classifier_reglog(self):
67+
iris = load_iris()
68+
X = iris.data[:, :2]
69+
y = iris.target
70+
y[y == 2] = 0
71+
72+
xgb = XGBClassifier(objective='reg:logistic')
73+
xgb.fit(X, y)
74+
conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
75+
self.assertTrue(conv_model is not None)
76+
dump_binary_classification(xgb, suffix="RegLog")
77+
5278

5379
if __name__ == "__main__":
5480
unittest.main()

0 commit comments

Comments
 (0)