Commit 060a71e

Author: Prabhat

Fixed XGBoost classifier converter output labels (#336)

* Fixed XGBoost classifier converter output labels
* Removed commented-out code
* Fixed typo
* Added unit test with discrete int labels
* Fixed XGBoost converters' unit tests

1 parent 303273b · commit 060a71e
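In short: before this commit the classifier converter always emitted the class labels 0..ncl-1, so models trained on string labels or on non-contiguous integer labels came back with incorrect output labels. The fix reads xgb_node.classes_ and chooses the ONNX label attribute from its dtype. A minimal sketch of that dispatch, mirroring the logic added in XGBoost.py below (pick_class_labels is a hypothetical helper for illustration; the real code writes into attr_pairs inside the converter):

import numpy as np

def pick_class_labels(classes):
    # Numeric classes become int64 labels; anything else becomes UTF-8 string labels.
    if (np.issubdtype(classes.dtype, np.floating)
            or np.issubdtype(classes.dtype, np.signedinteger)):
        return {'classlabels_int64s': classes.astype('int')}
    return {'classlabels_strings': np.array([s.encode('utf-8') for s in classes])}

pick_class_labels(np.array([10, 20, -30]))  # -> {'classlabels_int64s': array([10, 20, -30])}
pick_class_labels(np.array(['a', 'b']))     # -> {'classlabels_strings': array([b'a', b'b'])}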

File tree: 3 files changed (+156 / -58 lines)


onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 10 additions & 10 deletions
@@ -4,15 +4,10 @@
 # license information.
 # --------------------------------------------------------------------------
 
-import ctypes
-import numbers
-import numpy
 import json
-from xgboost import XGBRegressor, XGBClassifier
-from xgboost.core import _LIB, _check_call, from_cstr_to_pystr
-from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
+import numpy as np
+from xgboost import XGBClassifier
 from ...common._registration import register_converter
-from ...common import utils
 from ..common import get_xgb_params
 
 
@@ -29,7 +24,7 @@ def get_xgb_params(xgb_node):
     def validate(xgb_node):
         params = XGBConverter.get_xgb_params(xgb_node)
         try:
-            if not "objective" in params:
+            if "objective" not in params:
                 raise AttributeError('ojective')
         except AttributeError as e:
             raise RuntimeError('Missing attribute in XGBoost model ' + str(e))
@@ -238,7 +233,13 @@ def convert(scope, operator, container):
         attr_pairs['class_ids'] = [v % ncl for v in attr_pairs['class_treeids']]
         class_labels = list(range(ncl))
 
-        attr_pairs['classlabels_int64s'] = class_labels
+        classes = xgb_node.classes_
+        if (np.issubdtype(classes.dtype, np.floating) or
+                np.issubdtype(classes.dtype, np.signedinteger)):
+            attr_pairs['classlabels_int64s'] = classes.astype('int')
+        else:
+            classes = np.array([s.encode('utf-8') for s in classes])
+            attr_pairs['classlabels_strings'] = classes
 
         # add nodes
         if objective == "binary:logistic":
@@ -262,7 +263,6 @@ def convert(scope, operator, container):
             raise RuntimeError("Unexpected objective: {0}".format(objective))
 
 
-
 def convert_xgboost(scope, operator, container):
     xgb_node = operator.raw_operator
     if isinstance(xgb_node, XGBClassifier):

onnxmltools/convert/xgboost/shape_calculators/Classifier.py

Lines changed: 11 additions & 2 deletions
@@ -4,9 +4,13 @@
 # license information.
 # --------------------------------------------------------------------------
 
+import numpy as np
 from ...common._registration import register_shape_calculator
 from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
-from ...common.data_types import Int64TensorType, FloatTensorType, DictionaryType, SequenceType
+from ...common.data_types import (
+    DictionaryType, FloatTensorType, Int64TensorType,
+    SequenceType, StringTensorType,
+)
 from ..common import get_xgb_params
 
 
@@ -28,7 +32,12 @@ def calculate_xgboost_classifier_output_shapes(operator):
     ncl = ntrees // params['n_estimators']
     if objective == "reg:logistic" and ncl == 1:
         ncl = 2
-    operator.outputs[0].type = Int64TensorType(shape=[N])
+    classes = xgb_node.classes_
+    if (np.issubdtype(classes.dtype, np.floating) or
+            np.issubdtype(classes.dtype, np.signedinteger)):
+        operator.outputs[0].type = Int64TensorType(shape=[N])
+    else:
+        operator.outputs[0].type = StringTensorType(shape=[N])
     operator.outputs[1].type = operator.outputs[1].type = FloatTensorType([N, ncl])
 
 
tests/xgboost/test_xgboost_converters.py

Lines changed: 135 additions & 46 deletions
@@ -3,77 +3,166 @@
 """
 import sys
 import unittest
-from sklearn.datasets import load_iris
+import numpy as np
+from sklearn.datasets import load_diabetes, load_iris, make_classification
+from sklearn.model_selection import train_test_split
 from xgboost import XGBRegressor, XGBClassifier
 from onnxmltools.convert import convert_xgboost
 from onnxmltools.convert.common.data_types import FloatTensorType
-from onnxmltools.utils import dump_multiple_classification, dump_single_regression, dump_binary_classification
+from onnxmltools.utils import dump_data_and_model
+
+
+def _fit_classification_model(model, n_classes, is_str=False):
+    x, y = make_classification(n_classes=n_classes, n_features=100,
+                               n_samples=1000,
+                               random_state=42, n_informative=7)
+    y = y.astype(np.str) if is_str else y.astype(np.int64)
+    x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
+                                                   random_state=42)
+    model.fit(x_train, y_train)
+    return model, x_test.astype(np.float32)
 
 
 class TestXGBoostModels(unittest.TestCase):
 
-    @unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
     def test_xgb_regressor(self):
-        iris = load_iris()
-        X = iris.data[:, :2]
+        iris = load_diabetes()
+        x = iris.data
         y = iris.target
-
+        x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
+                                                       random_state=42)
         xgb = XGBRegressor()
-        xgb.fit(X, y)
-        conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
+        xgb.fit(x_train, y_train)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
         self.assertTrue(conv_model is not None)
-        dump_single_regression(xgb, suffix="-Dec4")
+        dump_data_and_model(
+            x_test.astype("float32"),
+            xgb,
+            conv_model,
+            basename="SklearnXGBRegressor-Dec4",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
 
-    @unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
     def test_xgb_classifier(self):
-        iris = load_iris()
-        X = iris.data[:, :2]
-        y = iris.target
-        y[y == 2] = 0
-
-        xgb = XGBClassifier()
-        xgb.fit(X, y)
-        conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
+        xgb, x_test = _fit_classification_model(XGBClassifier(), 2)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
         self.assertTrue(conv_model is not None)
-        dump_binary_classification(xgb)
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifier",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
 
-    @unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
     def test_xgb_classifier_multi(self):
-        iris = load_iris()
-        X = iris.data[:, :2]
-        y = iris.target
-
-        xgb = XGBClassifier()
-        xgb.fit(X, y)
-        conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
+        xgb, x_test = _fit_classification_model(XGBClassifier(), 3)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
         self.assertTrue(conv_model is not None)
-        dump_multiple_classification(xgb, allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')")
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifierMulti",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
 
-    @unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
     def test_xgb_classifier_multi_reglog(self):
-        iris = load_iris()
-        X = iris.data[:, :2]
-        y = iris.target
-
-        xgb = XGBClassifier(objective='reg:logistic')
-        xgb.fit(X, y)
-        conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
+        xgb, x_test = _fit_classification_model(
+            XGBClassifier(objective='reg:logistic'), 4)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
         self.assertTrue(conv_model is not None)
-        dump_multiple_classification(xgb, suffix="RegLog",
-                                     allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')")
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifierMultiRegLog",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
 
-    @unittest.skipIf(sys.version_info[0] == 2, reason="xgboost converter not tested on python 2")
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
     def test_xgb_classifier_reglog(self):
+        xgb, x_test = _fit_classification_model(
+            XGBClassifier(objective='reg:logistic'), 2)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
+        self.assertTrue(conv_model is not None)
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifierRegLog",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
+
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
+    def test_xgb_classifier_multi_str_labels(self):
+        xgb, x_test = _fit_classification_model(
+            XGBClassifier(n_estimators=4), 5, is_str=True)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
+        self.assertTrue(conv_model is not None)
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifierMultiStrLabels",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
+
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
+    def test_xgb_classifier_multi_discrete_int_labels(self):
         iris = load_iris()
-        X = iris.data[:, :2]
+        x = iris.data[:, :2]
         y = iris.target
-        y[y == 2] = 0
-
-        xgb = XGBClassifier(objective='reg:logistic')
-        xgb.fit(X, y)
-        conv_model = convert_xgboost(xgb, initial_types=[('input', FloatTensorType(shape=[1, 2]))])
+        y[y == 0] = 10
+        y[y == 1] = 20
+        y[y == 2] = -30
+        x_train, x_test, y_train, _ = train_test_split(x,
+                                                       y,
+                                                       test_size=0.5,
+                                                       random_state=42)
+        xgb = XGBClassifier(n_estimators=3)
+        xgb.fit(x_train, y_train)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[('input', FloatTensorType(shape=[1, 'None']))])
         self.assertTrue(conv_model is not None)
-        dump_binary_classification(xgb, suffix="RegLog")
+        dump_data_and_model(
+            x_test.astype("float32"),
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifierMultiDiscreteIntLabels",
+            allow_failure="StrictVersion("
+                          "onnx.__version__)"
+                          "< StrictVersion('1.3.0')",
+        )
 
 
 if __name__ == "__main__":
