Skip to content

Commit c1651b1

Browse files
authored
Fix discovering classifier objective (#480)
* Fix scoring classifier objective
1 parent d981dac commit c1651b1

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

onnxmltools/convert/xgboost/_parse.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,16 @@ def _get_attributes(booster):
5959
reg = re.compile(b'(multi:[a-z]{1,15})')
6060
objs = list(set(reg.findall(bstate)))
6161
if len(objs) != 1:
62-
raise RuntimeError(
63-
"Unable to guess objective in {}.".format(objs))
64-
kwargs['num_class'] = trees // ntrees
65-
kwargs["objective"] = objs[0].decode('ascii')
62+
if '"name":"binary:logistic"' in str(bstate):
63+
kwargs['num_class'] = 1
64+
kwargs["objective"] = "binary:logistic"
65+
else:
66+
raise RuntimeError(
67+
"Unable to guess objective in %r (trees=%r, ntrees=%r)"
68+
"." % (objs, trees, ntrees))
69+
else:
70+
kwargs['num_class'] = trees // ntrees
71+
kwargs["objective"] = objs[0].decode('ascii')
6672
else:
6773
kwargs['num_class'] = 1
6874
kwargs["objective"] = "binary:logistic"

0 commit comments

Comments
 (0)