Skip to content

Commit ccddab5

Browse files
xadupre and sdpython authored
Fix #400, support multi:softmax objective (#442)
* investigate xgboost issues Signed-off-by: xavier dupré <[email protected]> * fix softmax score Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * test nightly build Signed-off-by: xavier dupré <[email protected]> * restore ci Signed-off-by: xavier dupré <[email protected]> * restore one file Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent 5805a9e commit ccddab5

File tree

5 files changed

+43
-7
lines changed

5 files changed

+43
-7
lines changed

.azure-pipelines/linux-conda-CI.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ jobs:
4949
ONNXRT_PATH: onnxruntime==1.6.0
5050
COREML_PATH: git+https://github.com/apple/coremltools@3.1
5151
xgboost.version: '>=1.0'
52+
Python37-180-RT160-xgb11:
53+
python.version: '3.7'
54+
ONNX_PATH: onnx==1.8.0
55+
ONNXRT_PATH: onnxruntime==1.6.0
56+
COREML_PATH: git+https://github.com/apple/coremltools@3.1
57+
xgboost.version: '<1.2'
5258
maxParallel: 3
5359

5460
steps:

onnxmltools/convert/xgboost/_parse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import json
4+
import re
45
import numpy as np
56
from xgboost import XGBRegressor, XGBClassifier
67
from onnxconverter_common.data_types import FloatTensorType
@@ -53,8 +54,15 @@ def _get_attributes(booster):
5354
# classification
5455
kwargs['num_target'] = 0
5556
if trees > ntrees > 0:
57+
state = booster.__getstate__()
58+
bstate = bytes(state['handle'])
59+
reg = re.compile(b'(multi:[a-z]{1,15})')
60+
objs = list(set(reg.findall(bstate)))
61+
if len(objs) != 1:
62+
raise RuntimeError(
63+
"Unable to guess objective in {}.".format(objs))
5664
kwargs['num_class'] = trees // ntrees
57-
kwargs["objective"] = "multi:softprob"
65+
kwargs["objective"] = objs[0].decode('ascii')
5866
else:
5967
kwargs['num_class'] = 1
6068
kwargs["objective"] = "binary:logistic"

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def convert(scope, operator, container):
245245
else:
246246
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
247247
attr_pairs['post_transform'] = "SOFTMAX"
248-
# attr_pairs['base_values'] = [base_score for n in range(ncl)]
248+
attr_pairs['base_values'] = [base_score for n in range(ncl)]
249249
attr_pairs['class_ids'] = [v % ncl for v in attr_pairs['class_treeids']]
250250

251251
classes = xgb_node.classes_
@@ -264,8 +264,10 @@ def convert(scope, operator, container):
264264
op_domain='ai.onnx.ml',
265265
name=scope.get_unique_operator_name('TreeEnsembleClassifier'),
266266
**attr_pairs)
267-
elif objective == "multi:softprob":
267+
elif objective in ("multi:softprob", "multi:softmax"):
268268
ncl = len(js_trees) // params['n_estimators']
269+
if objective == 'multi:softmax':
270+
attr_pairs['post_transform'] = 'NONE'
269271
container.add_node('TreeEnsembleClassifier', operator.input_full_names,
270272
operator.output_full_names,
271273
op_domain='ai.onnx.ml',

onnxmltools/utils/tests_helper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,12 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
9797
if model_dict['objective'].startswith('binary'):
9898
score = model.predict(datax)
9999
prediction = [score > 0.5, numpy.vstack([1-score, score]).T]
100-
elif model_dict['objective'].startswith('multi'):
100+
elif model_dict['objective'].startswith('multi:softprob'):
101101
score = model.predict(datax)
102102
prediction = [score.argmax(axis=1), score]
103+
elif model_dict['objective'].startswith('multi:softmax'):
104+
score = model.predict(datax, output_margin=True)
105+
prediction = [score.argmax(axis=1), score]
103106
else:
104107
prediction = [model.predict(datax)]
105108
elif hasattr(model, "predict_proba"):

tests/xgboost/test_xgboost_converters.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Tests scikit-learn's tree-based methods' converters.
55
"""
66
import os
7-
import sys
87
import unittest
98
import numpy as np
109
import pandas
@@ -192,7 +191,7 @@ def test_xgboost_booster_classifier_bin(self):
192191
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
193192
basename="XGBBoosterMCl")
194193

195-
def test_xgboost_booster_classifier_multiclass(self):
194+
def test_xgboost_booster_classifier_multiclass_softprob(self):
196195
x, y = make_classification(n_classes=3, n_features=5,
197196
n_samples=100,
198197
random_state=42, n_informative=3)
@@ -208,7 +207,25 @@ def test_xgboost_booster_classifier_multiclass(self):
208207
dump_data_and_model(x_test.astype(np.float32),
209208
model, model_onnx,
210209
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
211-
basename="XGBBoosterMCl")
210+
basename="XGBBoosterMClSoftProb")
211+
212+
def test_xgboost_booster_classifier_multiclass_softmax(self):
213+
x, y = make_classification(n_classes=3, n_features=5,
214+
n_samples=100,
215+
random_state=42, n_informative=3)
216+
x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
217+
random_state=42)
218+
219+
data = DMatrix(x_train, label=y_train)
220+
model = train({'objective': 'multi:softmax',
221+
'n_estimators': 3, 'min_child_samples': 1,
222+
'num_class': 3}, data)
223+
model_onnx = convert_xgboost(model, 'tree-based classifier',
224+
[('input', FloatTensorType([None, x.shape[1]]))])
225+
dump_data_and_model(x_test.astype(np.float32),
226+
model, model_onnx,
227+
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
228+
basename="XGBBoosterMClSoftMax")
212229

213230
def test_xgboost_booster_classifier_reg(self):
214231
x, y = make_classification(n_classes=2, n_features=5,

0 commit comments

Comments
 (0)