@@ -6,21 +6,25 @@
 import unittest
 import numpy as np
 import pandas
-from sklearn.datasets import load_diabetes, load_iris, make_classification
+from sklearn.datasets import (
+    load_diabetes, load_iris, make_classification, load_digits)
 from sklearn.model_selection import train_test_split
 from xgboost import XGBRegressor, XGBClassifier, train, DMatrix
+from sklearn.preprocessing import StandardScaler
 from onnxmltools.convert import convert_xgboost
 from onnxmltools.convert.common.data_types import FloatTensorType
 from onnxmltools.utils import dump_data_and_model
 
 
-def _fit_classification_model(model, n_classes, is_str=False):
+def _fit_classification_model(model, n_classes, is_str=False, dtype=None):
     x, y = make_classification(n_classes=n_classes, n_features=100,
                                n_samples=1000,
                                random_state=42, n_informative=7)
     y = y.astype(str) if is_str else y.astype(np.int64)
     x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
                                                    random_state=42)
+    if dtype is not None:
+        y_train = y_train.astype(dtype)
     model.fit(x_train, y_train)
     return model, x_test.astype(np.float32)
 
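The new dtype argument casts the training labels just before fit, so a test can build a classifier whose class labels carry a non-default dtype such as np.uint8. A minimal standalone sketch of that path, assuming the imports above are available (the variable names below are illustrative, not part of the patch):

    model, x_test = _fit_classification_model(XGBClassifier(), 2,
                                              dtype=np.uint8)
    # With uint8 labels, the fitted classifier's classes_ attribute is
    # expected to inherit that dtype, which the ONNX converter must then
    # handle just like the default int64 labels.
    print(model.classes_.dtype)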
@@ -67,6 +71,24 @@ def test_xgb_classifier(self): |
             "< StrictVersion('1.3.0')",
         )
 
+    @unittest.skipIf(sys.version_info[0] == 2,
+                     reason="xgboost converter not tested on python 2")
+    def test_xgb_classifier_uint8(self):
+        xgb, x_test = _fit_classification_model(
+            XGBClassifier(), 2, dtype=np.uint8)
+        conv_model = convert_xgboost(
+            xgb, initial_types=[
+                ('input', FloatTensorType(shape=[None, None]))])
+        self.assertTrue(conv_model is not None)
+        dump_data_and_model(
+            x_test,
+            xgb,
+            conv_model,
+            basename="SklearnXGBClassifier",
+            allow_failure="StrictVersion("
+            "onnx.__version__)"
+            "< StrictVersion('1.3.0')",
+        )
+
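dump_data_and_model stores the inputs, the expected outputs, and the ONNX file for the backend tests. The converted model can also be checked directly against onnxruntime; a minimal sketch, assuming onnxruntime is installed and reusing conv_model, xgb, and x_test from the test above (the exact-label comparison is illustrative; probabilities would only agree up to float tolerance):

    import onnxruntime as rt

    sess = rt.InferenceSession(conv_model.SerializeToString(),
                               providers=["CPUExecutionProvider"])
    # 'input' is the name declared in initial_types above; the first
    # output of a converted classifier holds the predicted labels.
    onnx_labels = sess.run(None, {"input": x_test})[0]
    np.testing.assert_array_equal(onnx_labels, xgb.predict(x_test))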
     @unittest.skipIf(sys.version_info[0] == 2,
                      reason="xgboost converter not tested on python 2")
     def test_xgb_classifier_multi(self):
@@ -260,6 +282,30 @@ def test_xgboost_10(self): |
             allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
             basename="XGBBoosterRegBug")
 
+    def test_xgboost_example_mnist(self):
+        """
+        Train a simple xgboost model and store associated artefacts.
+        """
+        X, y = load_digits(return_X_y=True)
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+        X_train = X_train.reshape((X_train.shape[0], -1))
+        X_test = X_test.reshape((X_test.shape[0], -1))
+
+        scaler = StandardScaler()
+        X_train = scaler.fit_transform(X_train)
+        X_test = scaler.transform(X_test)
+        clf = XGBClassifier(objective="multi:softprob", n_jobs=-1)
+        clf.fit(X_train, y_train)
+
+        sh = [None, X_train.shape[1]]
+        onnx_model = convert_xgboost(
+            clf, initial_types=[('input', FloatTensorType(sh))])
+
+        dump_data_and_model(
+            X_test.astype(np.float32), clf, onnx_model,
+            allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
+            basename="XGBoostExample")
+
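Beyond the artefacts written by dump_data_and_model, the object returned by convert_xgboost is a plain ONNX ModelProto, so it can be persisted with the usual protobuf serialization; a short sketch (the file name is illustrative):

    # Save the converted model; any .onnx path works.
    with open("xgboost_mnist.onnx", "wb") as f:
        f.write(onnx_model.SerializeToString())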
 
 if __name__ == "__main__":
     unittest.main()