Skip to content

Commit 18c4dc1

Browse files
authored
Fixes #396, xgboost converter for xgboost >= 1.0.2 (#397)
1 parent 4e7c0f0 commit 18c4dc1

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

onnxmltools/convert/xgboost/common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,9 @@ def get_xgb_params(xgb_node):
1212
else:
1313
# XGBoost < 0.7
1414
params = xgb_node.__dict__
15-
15+
16+
if ('n_estimators' not in params and
17+
hasattr(xgb_node, 'n_estimators')):
18+
# xgboost >= 1.0.2
19+
params['n_estimators'] = xgb_node.n_estimators
1620
return params

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ scikit-learn
1414
scipy
1515
svm
1616
wheel
17-
xgboost<=1.0.2
17+
xgboost

0 commit comments

Comments
 (0)