Skip to content

Commit adc41ee

Browse files
xadupre, sdpython, and tirkarthi
authored
Replace #507 + fix bug with XGBoost converter when base_score is None (#510)
* Replace #507
* fix xgboost converter with xgboost 1.5.0
* fix nan values
* update catboost test
* update ci

Signed-off-by: xavier dupré <[email protected]>
Co-authored-by: xavier dupré <[email protected]>
Co-authored-by: Karthikeyan Singaravelan <[email protected]>
1 parent 4a44958 commit adc41ee

File tree

5 files changed: 48 additions (+48) and 90 deletions (-90).

5 files changed

+48
-90
lines changed

.azure-pipelines/win32-conda-CI.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ jobs:
7777
python -m pip install -r requirements-dev.txt
7878
displayName: 'Install dependencies-dev'
7979
80+
- script: |
81+
call activate py$(python.version)
82+
python -m pip install --upgrade scikit-learn
83+
displayName: 'Install scikit-learn'
84+
8085
- script: |
8186
call activate py$(python.version)
8287
python -m pip install %COREML_PATH%

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ input2 = Input(shape=(D,))
7676
mapped1_2 = sub_model1(input1)
7777
mapped2_2 = sub_model2(input2)
7878
sub_sum = Add()([mapped1_2, mapped2_2])
79-
keras_model = Model(inputs=[input1, input2], output=sub_sum)
79+
keras_model = Model(inputs=[input1, input2], outputs=sub_sum)
8080

8181
# Convert it! The target_opset parameter is optional.
8282
onnx_model = onnxmltools.convert_keras(keras_model, target_opset=7)

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def common_members(xgb_node, inputs):
3333
params = XGBConverter.get_xgb_params(xgb_node)
3434
objective = params["objective"]
3535
base_score = params["base_score"]
36+
if base_score is None:
37+
base_score = 0.5
3638
booster = xgb_node.get_booster()
3739
# The json format was available in October 2017.
3840
# XGBoost 0.7 was the first version released with it.

tests/catboost/test_CatBoost_converter.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,25 @@
44
Tests for CatBoostRegressor and CatBoostClassifier converter.
55
"""
66
import unittest
7-
import numpy
87
import warnings
9-
import catboost
10-
11-
from sklearn.datasets import make_regression, make_classification
8+
from distutils.version import StrictVersion
9+
import numpy
10+
try:
11+
import sklearn
12+
from sklearn.datasets import make_regression, make_classification
13+
except (ImportError, FileNotFoundError):
14+
sklearn = None
15+
try:
16+
import catboost
17+
except (ImportError, FileNotFoundError):
18+
catboost = None
1219
from onnxmltools.convert import convert_catboost
1320
from onnxmltools.utils import dump_data_and_model, dump_single_regression, dump_multiple_classification
1421

1522

1623
class TestCatBoost(unittest.TestCase):
24+
25+
@unittest.skipIf(catboost is None or sklearn is None, reason="catboost not imported")
1726
def test_catboost_regressor(self):
1827
X, y = make_regression(n_samples=100, n_features=4, random_state=0)
1928
catboost_model = catboost.CatBoostRegressor(task_type='CPU', loss_function='RMSE',
@@ -26,11 +35,11 @@ def test_catboost_regressor(self):
2635
self.assertTrue(catboost_onnx is not None)
2736
dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostReg-Dec4")
2837

38+
@unittest.skipIf(catboost is None or sklearn is None, reason="catboost not imported")
2939
def test_catboost_bin_classifier(self):
3040
import onnxruntime
31-
from distutils.version import StrictVersion
3241

33-
if StrictVersion(onnxruntime.__version__) >= StrictVersion('1.3.0'):
42+
if StrictVersion('.'.join(onnxruntime.__version__.split('.')[:2])) >= StrictVersion('1.3.0'):
3443
X, y = make_classification(n_samples=100, n_features=4, random_state=0)
3544
catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='CrossEntropy',
3645
n_estimators=10, verbose=0)
@@ -45,6 +54,7 @@ def test_catboost_bin_classifier(self):
4554
warnings.warn('Converted CatBoost models for binary classification work with onnxruntime version 1.3.0 or '
4655
'a newer one')
4756

57+
@unittest.skipIf(catboost is None or sklearn is None, reason="catboost not imported")
4858
def test_catboost_multi_classifier(self):
4959
X, y = make_classification(n_samples=10, n_informative=8, n_classes=3, random_state=0)
5060
catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='MultiClass',

tests/xgboost/test_xgboost_converters.py

Lines changed: 24 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,8 @@ def test_xgb_regressor(self):
5353
self.assertTrue(conv_model is not None)
5454
dump_data_and_model(
5555
x_test.astype("float32"),
56-
xgb,
57-
conv_model,
58-
basename="SklearnXGBRegressor-Dec3",
59-
allow_failure="StrictVersion("
60-
"onnx.__version__)"
61-
"< StrictVersion('1.3.0')",
62-
)
56+
xgb, conv_model,
57+
basename="SklearnXGBRegressor-Dec3")
6358

6459
def test_xgb_classifier(self):
6560
xgb, x_test = _fit_classification_model(XGBClassifier(), 2)
@@ -68,14 +63,8 @@ def test_xgb_classifier(self):
6863
target_opset=TARGET_OPSET)
6964
self.assertTrue(conv_model is not None)
7065
dump_data_and_model(
71-
x_test,
72-
xgb,
73-
conv_model,
74-
basename="SklearnXGBClassifier",
75-
allow_failure="StrictVersion("
76-
"onnx.__version__)"
77-
"< StrictVersion('1.3.0')",
78-
)
66+
x_test, xgb, conv_model,
67+
basename="SklearnXGBClassifier")
7968

8069
def test_xgb_classifier_uint8(self):
8170
xgb, x_test = _fit_classification_model(
@@ -85,14 +74,8 @@ def test_xgb_classifier_uint8(self):
8574
target_opset=TARGET_OPSET)
8675
self.assertTrue(conv_model is not None)
8776
dump_data_and_model(
88-
x_test,
89-
xgb,
90-
conv_model,
91-
basename="SklearnXGBClassifier",
92-
allow_failure="StrictVersion("
93-
"onnx.__version__)"
94-
"< StrictVersion('1.3.0')",
95-
)
77+
x_test, xgb, conv_model,
78+
basename="SklearnXGBClassifier")
9679

9780
def test_xgb_classifier_multi(self):
9881
xgb, x_test = _fit_classification_model(XGBClassifier(), 3)
@@ -101,14 +84,8 @@ def test_xgb_classifier_multi(self):
10184
target_opset=TARGET_OPSET)
10285
self.assertTrue(conv_model is not None)
10386
dump_data_and_model(
104-
x_test,
105-
xgb,
106-
conv_model,
107-
basename="SklearnXGBClassifierMulti",
108-
allow_failure="StrictVersion("
109-
"onnx.__version__)"
110-
"< StrictVersion('1.3.0')",
111-
)
87+
x_test, xgb, conv_model,
88+
basename="SklearnXGBClassifierMulti")
11289

11390
def test_xgb_classifier_multi_reglog(self):
11491
xgb, x_test = _fit_classification_model(
@@ -118,14 +95,8 @@ def test_xgb_classifier_multi_reglog(self):
11895
target_opset=TARGET_OPSET)
11996
self.assertTrue(conv_model is not None)
12097
dump_data_and_model(
121-
x_test,
122-
xgb,
123-
conv_model,
124-
basename="SklearnXGBClassifierMultiRegLog",
125-
allow_failure="StrictVersion("
126-
"onnx.__version__)"
127-
"< StrictVersion('1.3.0')",
128-
)
98+
x_test, xgb, conv_model,
99+
basename="SklearnXGBClassifierMultiRegLog")
129100

130101
def test_xgb_classifier_reglog(self):
131102
xgb, x_test = _fit_classification_model(
@@ -135,14 +106,8 @@ def test_xgb_classifier_reglog(self):
135106
target_opset=TARGET_OPSET)
136107
self.assertTrue(conv_model is not None)
137108
dump_data_and_model(
138-
x_test,
139-
xgb,
140-
conv_model,
141-
basename="SklearnXGBClassifierRegLog",
142-
allow_failure="StrictVersion("
143-
"onnx.__version__)"
144-
"< StrictVersion('1.3.0')",
145-
)
109+
x_test, xgb, conv_model,
110+
basename="SklearnXGBClassifierRegLog")
146111

147112
def test_xgb_classifier_multi_str_labels(self):
148113
xgb, x_test = _fit_classification_model(
@@ -152,14 +117,8 @@ def test_xgb_classifier_multi_str_labels(self):
152117
target_opset=TARGET_OPSET)
153118
self.assertTrue(conv_model is not None)
154119
dump_data_and_model(
155-
x_test,
156-
xgb,
157-
conv_model,
158-
basename="SklearnXGBClassifierMultiStrLabels",
159-
allow_failure="StrictVersion("
160-
"onnx.__version__)"
161-
"< StrictVersion('1.3.0')",
162-
)
120+
x_test, xgb, conv_model,
121+
basename="SklearnXGBClassifierMultiStrLabels")
163122

164123
def test_xgb_classifier_multi_discrete_int_labels(self):
165124
iris = load_iris()
@@ -180,13 +139,8 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
180139
self.assertTrue(conv_model is not None)
181140
dump_data_and_model(
182141
x_test.astype("float32"),
183-
xgb,
184-
conv_model,
185-
basename="SklearnXGBClassifierMultiDiscreteIntLabels",
186-
allow_failure="StrictVersion("
187-
"onnx.__version__)"
188-
"< StrictVersion('1.3.0')",
189-
)
142+
xgb, conv_model,
143+
basename="SklearnXGBClassifierMultiDiscreteIntLabels")
190144

191145
def test_xgboost_booster_classifier_bin(self):
192146
x, y = make_classification(n_classes=2, n_features=5,
@@ -202,9 +156,7 @@ def test_xgboost_booster_classifier_bin(self):
202156
[('input', FloatTensorType([None, x.shape[1]]))],
203157
target_opset=TARGET_OPSET)
204158
dump_data_and_model(x_test.astype(np.float32),
205-
model, model_onnx,
206-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
207-
basename="XGBBoosterMCl")
159+
model, model_onnx, basename="XGBBoosterMCl")
208160

209161
def test_xgboost_booster_classifier_multiclass_softprob(self):
210162
x, y = make_classification(n_classes=3, n_features=5,
@@ -221,9 +173,7 @@ def test_xgboost_booster_classifier_multiclass_softprob(self):
221173
[('input', FloatTensorType([None, x.shape[1]]))],
222174
target_opset=TARGET_OPSET)
223175
dump_data_and_model(x_test.astype(np.float32),
224-
model, model_onnx,
225-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
226-
basename="XGBBoosterMClSoftProb")
176+
model, model_onnx, basename="XGBBoosterMClSoftProb")
227177

228178
def test_xgboost_booster_classifier_multiclass_softmax(self):
229179
x, y = make_classification(n_classes=3, n_features=5,
@@ -240,9 +190,7 @@ def test_xgboost_booster_classifier_multiclass_softmax(self):
240190
[('input', FloatTensorType([None, x.shape[1]]))],
241191
target_opset=TARGET_OPSET)
242192
dump_data_and_model(x_test.astype(np.float32),
243-
model, model_onnx,
244-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
245-
basename="XGBBoosterMClSoftMax")
193+
model, model_onnx, basename="XGBBoosterMClSoftMax")
246194

247195
def test_xgboost_booster_classifier_reg(self):
248196
x, y = make_classification(n_classes=2, n_features=5,
@@ -259,9 +207,7 @@ def test_xgboost_booster_classifier_reg(self):
259207
[('input', FloatTensorType([None, x.shape[1]]))],
260208
target_opset=TARGET_OPSET)
261209
dump_data_and_model(x_test.astype(np.float32),
262-
model, model_onnx,
263-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
264-
basename="XGBBoosterReg")
210+
model, model_onnx, basename="XGBBoosterReg")
265211

266212
def test_xgboost_10(self):
267213
this = os.path.abspath(os.path.dirname(__file__))
@@ -279,9 +225,9 @@ def test_xgboost_10(self):
279225
}
280226

281227
train_df = pandas.read_csv(train)
282-
X_train, y_train = train_df.drop('label', axis=1).values, train_df['label'].values
228+
X_train, y_train = train_df.drop('label', axis=1).values, train_df['label'].fillna(0).values
283229
test_df = pandas.read_csv(test)
284-
X_test, y_test = test_df.drop('label', axis=1).values, test_df['label'].values
230+
X_test, y_test = test_df.drop('label', axis=1).values, test_df['label'].fillna(0).values
285231

286232
regressor = XGBRegressor(verbose=0, objective='reg:squarederror', **param_distributions)
287233
regressor.fit(X_train, y_train)
@@ -292,9 +238,7 @@ def test_xgboost_10(self):
292238
target_opset=TARGET_OPSET)
293239

294240
dump_data_and_model(
295-
X_test.astype(np.float32),
296-
regressor, model_onnx,
297-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
241+
X_test.astype(np.float32), regressor, model_onnx,
298242
basename="XGBBoosterRegBug")
299243

300244
def test_xgboost_classifier_i5450(self):
@@ -315,9 +259,7 @@ def test_xgboost_classifier_i5450(self):
315259
bst = clr.get_booster()
316260
bst.dump_model('dump.raw.txt')
317261
dump_data_and_model(
318-
X_test.astype(np.float32) + 1e-5,
319-
clr, onx,
320-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
262+
X_test.astype(np.float32) + 1e-5, clr, onx,
321263
basename="XGBClassifierIris")
322264

323265
def test_xgboost_example_mnist(self):
@@ -342,7 +284,6 @@ def test_xgboost_example_mnist(self):
342284

343285
dump_data_and_model(
344286
X_test.astype(np.float32), clf, onnx_model,
345-
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
346287
basename="XGBoostExample")
347288

348289
def test_xgb_empty_tree(self):

0 commit comments

Comments
 (0)