18 changes: 15 additions & 3 deletions onnxmltools/convert/xgboost/common.py
@@ -29,9 +29,16 @@ def get_xgb_params(xgb_node):
if xgb_node.n_estimators is not None:
params["n_estimators"] = xgb_node.n_estimators
if "base_score" in config["learner"]["learner_model_param"]:
bs = float(config["learner"]["learner_model_param"]["base_score"])
# xgboost >= 2.0
params["base_score"] = bs
base_score = config["learner"]["learner_model_param"]["base_score"]
if base_score.startswith("[") and base_score.endswith("]"):
    # xgboost >= 3.1 serializes base_score as a bracketed list
    base_score = [float(score) for score in base_score.strip("[]").split(",")]
    if len(base_score) == 1:
        base_score = base_score[0]
else:
    # xgboost >= 2.0 and < 3.1
    base_score = float(base_score)
params["base_score"] = base_score
if "num_target" in config["learner"]["learner_model_param"]:
params["n_targets"] = int(
config["learner"]["learner_model_param"]["num_target"]
@@ -48,6 +55,11 @@ def get_xgb_params(xgb_node):
params["best_ntree_limit"] = int(gbp["num_trees"])
return params

def base_score_as_list(base_score):
    """Return base_score as a list; xgboost >= 3.1 may already provide a list."""
    if isinstance(base_score, list):
        return base_score
    return [base_score]


def get_n_estimators_classifier(xgb_node, params, js_trees):
if "n_estimators" not in params:
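Note on the parsing change above: xgboost 2.x writes base_score into the saved booster config as a plain scalar string, while the new branch handles releases that serialize it as a bracketed list (labelled xgboost >= 3.1 in the diff). A minimal standalone sketch of that parsing logic, using the hypothetical helper name parse_base_score, which is not part of this PR:

def parse_base_score(raw: str):
    # Bracketed form, e.g. "[0.5]" or "[0.2,0.8]"; a single-element list
    # collapses back to a scalar, mirroring the branch in the diff above.
    if raw.startswith("[") and raw.endswith("]"):
        scores = [float(s) for s in raw.strip("[]").split(",")]
        return scores[0] if len(scores) == 1 else scores
    # Plain scalar string written by xgboost 2.x.
    return float(raw)


print(parse_base_score("0.5"))        # 0.5
print(parse_base_score("[0.5]"))      # 0.5
print(parse_base_score("[0.2,0.8]"))  # [0.2, 0.8]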
17 changes: 10 additions & 7 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -10,7 +10,7 @@
except ImportError:
XGBRFClassifier = None
from ...common._registration import register_converter
from ..common import get_xgb_params, get_n_estimators_classifier
from ..common import get_xgb_params, get_n_estimators_classifier, base_score_as_list


class XGBConverter:
@@ -259,7 +259,7 @@ def convert(scope, operator, container):
raise RuntimeError("Objective '{}' not supported.".format(objective))

attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
attr_pairs["base_values"] = [base_score]
attr_pairs["base_values"] = base_score_as_list(base_score)

if best_ntree_limit and best_ntree_limit < len(js_trees):
js_trees = js_trees[:best_ntree_limit]
@@ -350,17 +350,20 @@ def convert(scope, operator, container):
attr_pairs["post_transform"] = "LOGISTIC"
attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
if js_trees[0].get("leaf", None) == 0:
attr_pairs["base_values"] = [base_score]
attr_pairs["base_values"] = base_score_as_list(base_score)
elif base_score != 0.5:
# 0.5 -> cst = 0
cst = -np.log(1 / np.float32(base_score) - 1.0)
attr_pairs["base_values"] = [cst]
cst = -np.log(
    1 / np.array(base_score_as_list(base_score), dtype=np.float32) - 1.0
)
attr_pairs["base_values"] = cst.tolist()
else:
attr_pairs["base_values"] = [base_score]
attr_pairs["base_values"] = base_score_as_list(base_score)
else:
# See https://github.com/dmlc/xgboost/blob/main/src/common/math.h#L35.
attr_pairs["post_transform"] = "SOFTMAX"
attr_pairs["base_values"] = [base_score for n in range(ncl)]
if isinstance(base_score, list):
    attr_pairs["base_values"] = base_score
else:
    attr_pairs["base_values"] = [base_score for n in range(ncl)]
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]

classes = xgb_node.classes_
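For the binary/logistic branch above: when base_score differs from 0.5, the converter maps it to base_values through the inverse sigmoid, and with this PR the transform is applied element-wise over a possible list of scores. A small numpy sketch of that transform as a standalone illustration (logistic_base_values is a hypothetical name, not the converter code):

import numpy as np


def logistic_base_values(base_score):
    # Inverse sigmoid: sigmoid(cst) == base_score, so 0.5 maps to 0.0.
    # Accepts a scalar or a list, matching base_score_as_list in the diff.
    scores = base_score if isinstance(base_score, list) else [base_score]
    cst = -np.log(1.0 / np.array(scores, dtype=np.float32) - 1.0)
    return cst.tolist()


print(logistic_base_values(0.5))         # [0.0]
print(logistic_base_values([0.2, 0.8]))  # approx [-1.386, 1.386]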