Skip to content

Commit 24d1d99

Browse files
authored
fix base_score for binary classification (#566)
Signed-off-by: xadupre <[email protected]>
1 parent facefb2 commit 24d1d99

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def convert(scope, operator, container):
246246
attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']]
247247
if js_trees[0].get('leaf', None) == 0:
248248
attr_pairs['base_values'] = [0.5]
249+
elif base_score != 0.5:
250+
cst = - np.log(1 / np.float32(base_score) - 1.)
251+
attr_pairs['base_values'] = [cst]
249252
else:
250253
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
251254
attr_pairs['post_transform'] = "SOFTMAX"

tests/xgboost/test_xgboost_converters.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,27 @@ def test_xgb_best_tree_limit(self):
341341
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5)
342342
assert_almost_equal(bst_loaded.predict(dtest), res[0])
343343

344+
def test_onnxrt_python_xgbclassifier(self):
    """Compare ONNX probabilities with XGBClassifier for binary:logistic.

    Covers the default ``base_score`` (``None``) and an explicit one (the
    mean of the training labels), each with 1 and 3 estimators.
    """
    # A seeded local generator and a seeded split keep the data
    # reproducible, so a failure can be replayed exactly.
    rng = np.random.RandomState(0)
    x = rng.randn(100, 10).astype(np.float32)
    y = ((x.sum(axis=1) + rng.randn(x.shape[0]) / 50 + 0.5) >= 0).astype(np.int64)
    x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)
    bmy = np.mean(y_train)

    for bm, n_est in [(None, 1), (None, 3), (bmy, 1), (bmy, 3)]:
        # Fit and convert inside the subTest: an error raised for one
        # combination is reported with its parameters and does not stop
        # the remaining combinations from running.
        with self.subTest(base_score=bm, n_estimators=n_est):
            model_skl = XGBClassifier(
                n_estimators=n_est,
                learning_rate=0.01,
                subsample=0.5,
                objective="binary:logistic",
                base_score=bm,
                max_depth=2,
            )
            model_skl.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=0)

            model_onnx_skl = convert_xgboost(
                model_skl,
                initial_types=[('X', FloatTensorType([None, x.shape[1]]))],
                target_opset=TARGET_OPSET,
            )

            oinf = InferenceSession(model_onnx_skl.SerializeToString())
            res2 = oinf.run(None, {'X': x_test})
            # The second output (index 1) is compared with predict_proba.
            assert_almost_equal(model_skl.predict_proba(x_test), res2[1])
364+
344365

345366
if __name__ == "__main__":
346367
unittest.main()

0 commit comments

Comments
 (0)