Skip to content

Commit 7858f9f

Browse files
authored
Fix discrepancies with XGBRegressor and xgboost > 2 (#670)
* Fix discrepancies with XGBRegressor and xgboost > 2 Signed-off-by: Xavier Dupre <[email protected]> * improve ci version Signed-off-by: Xavier Dupre <[email protected]> * fix many issues Signed-off-by: Xavier Dupre <[email protected]> * more fixes Signed-off-by: Xavier Dupre <[email protected]> * many fixes for xgboost Signed-off-by: Xavier Dupre <[email protected]> * fix requiremnts. Signed-off-by: Xavier Dupre <[email protected]> * fix converters for old versions of xgboost Signed-off-by: Xavier Dupre <[email protected]> * fix pipeline and base_score Signed-off-by: Xavier Dupre <[email protected]> * poisson Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]>
1 parent 761c1cd commit 7858f9f

File tree

15 files changed

+286
-113
lines changed

15 files changed

+286
-113
lines changed

.azure-pipelines/linux-conda-CI.yml

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,23 @@ jobs:
1515
strategy:
1616
matrix:
1717

18+
Python311-1150-RT1163-xgb2-lgbm40:
19+
python.version: '3.11'
20+
ONNX_PATH: 'onnx==1.15.0'
21+
ONNXRT_PATH: 'onnxruntime==1.16.3'
22+
COREML_PATH: NONE
23+
lightgbm.version: '>=4.0'
24+
xgboost.version: '>=2'
25+
numpy.version: ''
26+
scipy.version: ''
27+
1828
Python311-1150-RT1160-xgb175-lgbm40:
1929
python.version: '3.11'
2030
ONNX_PATH: 'onnx==1.15.0'
2131
ONNXRT_PATH: 'onnxruntime==1.16.2'
2232
COREML_PATH: NONE
2333
lightgbm.version: '>=4.0'
24-
xgboost.version: '>=1.7.5'
34+
xgboost.version: '==1.7.5'
2535
numpy.version: ''
2636
scipy.version: ''
2737

@@ -31,7 +41,7 @@ jobs:
3141
ONNXRT_PATH: 'onnxruntime==1.16.2'
3242
COREML_PATH: NONE
3343
lightgbm.version: '>=4.0'
34-
xgboost.version: '>=1.7.5'
44+
xgboost.version: '==1.7.5'
3545
numpy.version: ''
3646
scipy.version: ''
3747

@@ -41,7 +51,7 @@ jobs:
4151
ONNXRT_PATH: 'onnxruntime==1.15.1'
4252
COREML_PATH: NONE
4353
lightgbm.version: '<4.0'
44-
xgboost.version: '>=1.7.5'
54+
xgboost.version: '==1.7.5'
4555
numpy.version: ''
4656
scipy.version: ''
4757

@@ -51,7 +61,7 @@ jobs:
5161
ONNXRT_PATH: 'onnxruntime==1.14.0'
5262
COREML_PATH: NONE
5363
lightgbm.version: '<4.0'
54-
xgboost.version: '>=1.7.5'
64+
xgboost.version: '==1.7.5'
5565
numpy.version: ''
5666
scipy.version: ''
5767

@@ -61,7 +71,7 @@ jobs:
6171
ONNXRT_PATH: 'onnxruntime==1.15.1'
6272
COREML_PATH: NONE
6373
lightgbm.version: '>=4.0'
64-
xgboost.version: '>=1.7.5'
74+
xgboost.version: '==1.7.5'
6575
numpy.version: ''
6676
scipy.version: '==1.8.0'
6777

.azure-pipelines/win32-conda-CI.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,39 @@ jobs:
2121
ONNXRT_PATH: 'onnxruntime==1.16.2'
2222
COREML_PATH: NONE
2323
numpy.version: ''
24+
xgboost.version: '2.0.2'
2425

2526
Python311-1141-RT1162:
2627
python.version: '3.11'
2728
ONNX_PATH: 'onnx==1.14.1'
2829
ONNXRT_PATH: 'onnxruntime==1.16.2'
2930
COREML_PATH: NONE
3031
numpy.version: ''
32+
xgboost.version: '1.7.5'
3133

3234
Python310-1141-RT1151:
3335
python.version: '3.10'
3436
ONNX_PATH: 'onnx==1.14.1'
3537
ONNXRT_PATH: 'onnxruntime==1.15.1'
3638
COREML_PATH: NONE
3739
numpy.version: ''
40+
xgboost.version: '1.7.5'
3841

3942
Python310-1141-RT1140:
4043
python.version: '3.10'
4144
ONNX_PATH: 'onnx==1.14.1'
4245
ONNXRT_PATH: onnxruntime==1.14.0
4346
COREML_PATH: NONE
4447
numpy.version: ''
48+
xgboost.version: '1.7.5'
4549

4650
Python39-1141-RT1140:
4751
python.version: '3.9'
4852
ONNX_PATH: 'onnx==1.14.1'
4953
ONNXRT_PATH: onnxruntime==1.14.0
5054
COREML_PATH: NONE
5155
numpy.version: ''
56+
xgboost.version: '1.7.5'
5257

5358
maxParallel: 3
5459

@@ -74,6 +79,8 @@ jobs:
7479
- script: |
7580
call activate py$(python.version)
7681
python -m pip install --upgrade scikit-learn
82+
python -m pip install --upgrade lightgbm
83+
python -m pip install "xgboost==$(xgboost.version)"
7784
displayName: 'Install scikit-learn'
7885
7986
- script: |

CHANGELOGS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 1.12.0
44

5+
* Fix discrepancies with XGBRegressor and xgboost > 2
6+
[#670](https://github.com/onnx/onnxmltools/pull/670)
57
* Support count:poisson for XGBRegressor
68
[#666](https://github.com/onnx/onnxmltools/pull/666)
79
* Supports XGBRFClassifier and XGBRFRegressor

onnxmltools/convert/xgboost/_parse.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,15 @@ def _get_attributes(booster):
7070
ntrees = trees // num_class if num_class > 0 else trees
7171
else:
7272
trees = len(res)
73-
ntrees = booster.best_ntree_limit
74-
num_class = trees // ntrees
73+
ntrees = getattr(booster, "best_ntree_limit", trees)
74+
config = json.loads(booster.save_config())["learner"]["learner_model_param"]
75+
num_class = int(config["num_class"]) if "num_class" in config else 0
76+
if num_class == 0 and ntrees > 0:
77+
num_class = trees // ntrees
7578
if num_class == 0:
7679
raise RuntimeError(
77-
"Unable to retrieve the number of classes, trees=%d, ntrees=%d."
78-
% (trees, ntrees)
80+
f"Unable to retrieve the number of classes, num_class={num_class}, "
81+
f"trees={trees}, ntrees={ntrees}, config={config}."
7982
)
8083

8184
kwargs = atts.copy()
@@ -137,7 +140,7 @@ def __init__(self, booster):
137140
self.operator_name = "XGBRegressor"
138141

139142
def get_xgb_params(self):
140-
return self.kwargs
143+
return {k: v for k, v in self.kwargs.items() if v is not None}
141144

142145
def get_booster(self):
143146
return self.booster_

onnxmltools/convert/xgboost/common.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
Common function to converters and shape calculators.
55
"""
6+
import json
67

78

89
def get_xgb_params(xgb_node):
@@ -15,8 +16,29 @@ def get_xgb_params(xgb_node):
1516
else:
1617
# XGBoost < 0.7
1718
params = xgb_node.__dict__
18-
19+
if hasattr("xgb_node", "save_config"):
20+
config = json.loads(xgb_node.save_config())
21+
else:
22+
config = json.loads(xgb_node.get_booster().save_config())
23+
params = {k: v for k, v in params.items() if v is not None}
24+
num_class = int(config["learner"]["learner_model_param"]["num_class"])
25+
if num_class > 0:
26+
params["num_class"] = num_class
1927
if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
2028
# xgboost >= 1.0.2
21-
params["n_estimators"] = xgb_node.n_estimators
29+
if xgb_node.n_estimators is not None:
30+
params["n_estimators"] = xgb_node.n_estimators
31+
if "base_score" in config["learner"]["learner_model_param"]:
32+
bs = float(config["learner"]["learner_model_param"]["base_score"])
33+
# xgboost >= 2.0
34+
params["base_score"] = bs
2235
return params
36+
37+
38+
def get_n_estimators_classifier(xgb_node, params, js_trees):
39+
if "n_estimators" not in params:
40+
num_class = params.get("num_class", 0)
41+
if num_class == 0:
42+
return len(js_trees)
43+
return len(js_trees) // num_class
44+
return params["n_estimators"]

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
except ImportError:
1111
XGBRFClassifier = None
1212
from ...common._registration import register_converter
13-
from ..common import get_xgb_params
13+
from ..common import get_xgb_params, get_n_estimators_classifier
1414

1515

1616
class XGBConverter:
@@ -161,8 +161,7 @@ def _fill_node_attributes(
161161
false_child_id=remap[jsnode["no"]], # ['children'][1]['nodeid'],
162162
weights=None,
163163
weight_id_bias=None,
164-
missing=jsnode.get("missing", -1)
165-
== jsnode["yes"], # ['children'][0]['nodeid'],
164+
missing=jsnode.get("missing", -1) == jsnode["yes"],
166165
hitrate=jsnode.get("cover", 0),
167166
)
168167

@@ -265,8 +264,8 @@ def convert(scope, operator, container):
265264
)
266265

267266
if objective == "count:poisson":
268-
cst = scope.get_unique_variable_name("half")
269-
container.add_initializer(cst, TensorProto.FLOAT, [1], [0.5])
267+
cst = scope.get_unique_variable_name("poisson")
268+
container.add_initializer(cst, TensorProto.FLOAT, [1], [base_score])
270269
new_name = scope.get_unique_variable_name("exp")
271270
container.add_node("Exp", names, [new_name])
272271
container.add_node("Mul", [new_name, cst], operator.output_full_names)
@@ -293,11 +292,18 @@ def convert(scope, operator, container):
293292
objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
294293

295294
params = XGBConverter.get_xgb_params(xgb_node)
295+
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
296+
num_class = params.get("num_class", None)
297+
296298
attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
297299
XGBConverter.fill_tree_attributes(
298300
js_trees, attr_pairs, [1 for _ in js_trees], True
299301
)
300-
ncl = (max(attr_pairs["class_treeids"]) + 1) // params["n_estimators"]
302+
if num_class is not None:
303+
ncl = num_class
304+
n_estimators = len(js_trees) // ncl
305+
else:
306+
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
301307

302308
bst = xgb_node.get_booster()
303309
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
@@ -310,15 +316,17 @@ def convert(scope, operator, container):
310316

311317
if len(attr_pairs["class_treeids"]) == 0:
312318
raise RuntimeError("XGBoost model is empty.")
319+
313320
if ncl <= 1:
314321
ncl = 2
315322
if objective != "binary:hinge":
316323
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
317324
attr_pairs["post_transform"] = "LOGISTIC"
318325
attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
319326
if js_trees[0].get("leaf", None) == 0:
320-
attr_pairs["base_values"] = [0.5]
327+
attr_pairs["base_values"] = [base_score]
321328
elif base_score != 0.5:
329+
# 0.5 -> cst = 0
322330
cst = -np.log(1 / np.float32(base_score) - 1.0)
323331
attr_pairs["base_values"] = [cst]
324332
else:
@@ -330,8 +338,10 @@ def convert(scope, operator, container):
330338
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]
331339

332340
classes = xgb_node.classes_
333-
if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
334-
classes.dtype, np.integer
341+
if (
342+
np.issubdtype(classes.dtype, np.floating)
343+
or np.issubdtype(classes.dtype, np.integer)
344+
or np.issubdtype(classes.dtype, np.bool_)
335345
):
336346
attr_pairs["classlabels_int64s"] = classes.astype("int")
337347
else:
@@ -373,7 +383,7 @@ def convert(scope, operator, container):
373383
"Where", [greater, one, zero], operator.output_full_names[1]
374384
)
375385
elif objective in ("multi:softprob", "multi:softmax"):
376-
ncl = len(js_trees) // params["n_estimators"]
386+
ncl = len(js_trees) // n_estimators
377387
if objective == "multi:softmax":
378388
attr_pairs["post_transform"] = "NONE"
379389
container.add_node(
@@ -385,7 +395,7 @@ def convert(scope, operator, container):
385395
**attr_pairs,
386396
)
387397
elif objective == "reg:logistic":
388-
ncl = len(js_trees) // params["n_estimators"]
398+
ncl = len(js_trees) // n_estimators
389399
if ncl == 1:
390400
ncl = 2
391401
container.add_node(

onnxmltools/convert/xgboost/shape_calculators/Classifier.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Int64TensorType,
99
StringTensorType,
1010
)
11-
from ..common import get_xgb_params
11+
from ..common import get_xgb_params, get_n_estimators_classifier
1212

1313

1414
def calculate_xgboost_classifier_output_shapes(operator):
@@ -22,18 +22,26 @@ def calculate_xgboost_classifier_output_shapes(operator):
2222
params = get_xgb_params(xgb_node)
2323
booster = xgb_node.get_booster()
2424
booster.attributes()
25-
ntrees = len(booster.get_dump(with_stats=True, dump_format="json"))
25+
js_trees = booster.get_dump(with_stats=True, dump_format="json")
26+
ntrees = len(js_trees)
2627
objective = params["objective"]
28+
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
29+
num_class = params.get("num_class", None)
2730

28-
if objective == "binary:logistic":
31+
if num_class is not None:
32+
ncl = num_class
33+
n_estimators = ntrees // ncl
34+
elif objective == "binary:logistic":
2935
ncl = 2
3036
else:
31-
ncl = ntrees // params["n_estimators"]
37+
ncl = ntrees // n_estimators
3238
if objective == "reg:logistic" and ncl == 1:
3339
ncl = 2
3440
classes = xgb_node.classes_
35-
if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
36-
classes.dtype, np.integer
41+
if (
42+
np.issubdtype(classes.dtype, np.floating)
43+
or np.issubdtype(classes.dtype, np.integer)
44+
or np.issubdtype(classes.dtype, np.bool_)
3745
):
3846
operator.outputs[0].type = Int64TensorType(shape=[N])
3947
else:

onnxmltools/utils/utils_backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def compare_outputs(expected, output, **kwargs):
188188
Disc = kwargs.pop("Disc", False)
189189
Mism = kwargs.pop("Mism", False)
190190
Opp = kwargs.pop("Opp", False)
191+
if hasattr(expected, "dtype") and expected.dtype == numpy.bool_:
192+
expected = expected.astype(numpy.int64)
191193
if Opp and not NoProb:
192194
raise ValueError("Opp is only available if NoProb is True")
193195

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pytest-spark
1515
ruff
1616
scikit-learn>=1.2.0
1717
scipy
18+
skl2onnx
1819
wheel
19-
xgboost==1.7.5
20+
xgboost
2021
onnxruntime

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
numpy
2-
onnx
3-
skl2onnx
1+
numpy
2+
onnx

0 commit comments

Comments
 (0)