Skip to content

Commit a849358

Browse files
xadupre, sdpython, and wenbingl
authored
Fixes #421, support unsigned integer as class type (#426)
* Add example on xgboost mnist
  Signed-off-by: xavier dupré <[email protected]>
* fixes #421, allow unsigned integers as class label
  Signed-off-by: xavier dupré <[email protected]>

Co-authored-by: xavier dupré <[email protected]>
Co-authored-by: Wenbing Li <[email protected]>
1 parent fdbd2b4 commit a849358

File tree

3 files changed: +50 / -4 lines changed

3 files changed: +50 / -4 lines changed

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ def convert(scope, operator, container):
242242

243243
classes = xgb_node.classes_
244244
if (np.issubdtype(classes.dtype, np.floating) or
245-
np.issubdtype(classes.dtype, np.signedinteger)):
245+
np.issubdtype(classes.dtype, np.integer)):
246246
attr_pairs['classlabels_int64s'] = classes.astype('int')
247247
else:
248248
classes = np.array([s.encode('utf-8') for s in classes])

onnxmltools/convert/xgboost/shape_calculators/Classifier.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def calculate_xgboost_classifier_output_shapes(operator):
3434
ncl = 2
3535
classes = xgb_node.classes_
3636
if (np.issubdtype(classes.dtype, np.floating) or
37-
np.issubdtype(classes.dtype, np.signedinteger)):
37+
np.issubdtype(classes.dtype, np.integer)):
3838
operator.outputs[0].type = Int64TensorType(shape=[N])
3939
else:
4040
operator.outputs[0].type = StringTensorType(shape=[N])

tests/xgboost/test_xgboost_converters.py

Lines changed: 48 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,21 +6,25 @@
66
import unittest
77
import numpy as np
88
import pandas
9-
from sklearn.datasets import load_diabetes, load_iris, make_classification
9+
from sklearn.datasets import (
10+
load_diabetes, load_iris, make_classification, load_digits)
1011
from sklearn.model_selection import train_test_split
1112
from xgboost import XGBRegressor, XGBClassifier, train, DMatrix
13+
from sklearn.preprocessing import StandardScaler
1214
from onnxmltools.convert import convert_xgboost
1315
from onnxmltools.convert.common.data_types import FloatTensorType
1416
from onnxmltools.utils import dump_data_and_model
1517

1618

17-
def _fit_classification_model(model, n_classes, is_str=False):
19+
def _fit_classification_model(model, n_classes, is_str=False, dtype=None):
1820
x, y = make_classification(n_classes=n_classes, n_features=100,
1921
n_samples=1000,
2022
random_state=42, n_informative=7)
2123
y = y.astype(np.str) if is_str else y.astype(np.int64)
2224
x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
2325
random_state=42)
26+
if dtype is not None:
27+
y_train = y_train.astype(dtype)
2428
model.fit(x_train, y_train)
2529
return model, x_test.astype(np.float32)
2630

@@ -67,6 +71,24 @@ def test_xgb_classifier(self):
6771
"< StrictVersion('1.3.0')",
6872
)
6973

74+
@unittest.skipIf(sys.version_info[0] == 2,
75+
reason="xgboost converter not tested on python 2")
76+
def test_xgb_classifier_uint8(self):
77+
xgb, x_test = _fit_classification_model(
78+
XGBClassifier(), 2, dtype=np.uint8)
79+
conv_model = convert_xgboost(
80+
xgb, initial_types=[('input', FloatTensorType(shape=['None', 'None']))])
81+
self.assertTrue(conv_model is not None)
82+
dump_data_and_model(
83+
x_test,
84+
xgb,
85+
conv_model,
86+
basename="SklearnXGBClassifier",
87+
allow_failure="StrictVersion("
88+
"onnx.__version__)"
89+
"< StrictVersion('1.3.0')",
90+
)
91+
7092
@unittest.skipIf(sys.version_info[0] == 2,
7193
reason="xgboost converter not tested on python 2")
7294
def test_xgb_classifier_multi(self):
@@ -260,6 +282,30 @@ def test_xgboost_10(self):
260282
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
261283
basename="XGBBoosterRegBug")
262284

285+
def test_xgboost_example_mnist(self):
286+
"""
287+
Train a simple xgboost model and store associated artefacts.
288+
"""
289+
X, y = load_digits(return_X_y=True)
290+
X_train, X_test, y_train, y_test = train_test_split(X, y)
291+
X_train = X_train.reshape((X_train.shape[0], -1))
292+
X_test = X_test.reshape((X_test.shape[0], -1))
293+
294+
scaler = StandardScaler()
295+
X_train = scaler.fit_transform(X_train)
296+
X_test = scaler.transform(X_test)
297+
clf = XGBClassifier(objective="multi:softprob", n_jobs=-1)
298+
clf.fit(X_train, y_train)
299+
300+
sh = [None, X_train.shape[1]]
301+
onnx_model = convert_xgboost(
302+
clf, initial_types=[('input', FloatTensorType(sh))])
303+
304+
dump_data_and_model(
305+
X_test.astype(np.float32), clf, onnx_model,
306+
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
307+
basename="XGBoostExample")
308+
263309

264310
if __name__ == "__main__":
265311
unittest.main()

0 commit comments

Comments
 (0)