diff --git a/onnxmltools/convert/xgboost/common.py b/onnxmltools/convert/xgboost/common.py index 4a10a48f..8b8fd1a3 100644 --- a/onnxmltools/convert/xgboost/common.py +++ b/onnxmltools/convert/xgboost/common.py @@ -29,9 +29,16 @@ def get_xgb_params(xgb_node): if xgb_node.n_estimators is not None: params["n_estimators"] = xgb_node.n_estimators if "base_score" in config["learner"]["learner_model_param"]: - bs = float(config["learner"]["learner_model_param"]["base_score"]) - # xgboost >= 2.0 - params["base_score"] = bs + base_score = config["learner"]["learner_model_param"]["base_score"] + if(base_score.startswith('[') and base_score.endswith(']')): + # xgboost >= 3.1, see + base_score = [float(score) for score in base_score.strip('[]').split(',')] + if len(base_score) == 1: + base_score = base_score[0] + else: + #xgboost >= 2.0 and < 3.1 + base_score = float(base_score) + params["base_score"] = base_score if "num_target" in config["learner"]["learner_model_param"]: params["n_targets"] = int( config["learner"]["learner_model_param"]["num_target"] @@ -48,6 +55,11 @@ def get_xgb_params(xgb_node): params["best_ntree_limit"] = int(gbp["num_trees"]) return params +def base_score_as_list(base_score): + if isinstance(base_score, list): + return base_score + return [base_score] + def get_n_estimators_classifier(xgb_node, params, js_trees): if "n_estimators" not in params: diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py index aeec4364..3797d40a 100644 --- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py +++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py @@ -10,7 +10,7 @@ except ImportError: XGBRFClassifier = None from ...common._registration import register_converter -from ..common import get_xgb_params, get_n_estimators_classifier +from ..common import get_xgb_params, get_n_estimators_classifier, base_score_as_list class XGBConverter: @@ -259,7 +259,7 @@ def convert(scope, operator, container): raise RuntimeError("Objective '{}' not supported.".format(objective)) attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs() - attr_pairs["base_values"] = [base_score] + attr_pairs["base_values"] = base_score_as_list(base_score) if best_ntree_limit and best_ntree_limit < len(js_trees): js_trees = js_trees[:best_ntree_limit] @@ -350,17 +350,20 @@ def convert(scope, operator, container): attr_pairs["post_transform"] = "LOGISTIC" attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]] if js_trees[0].get("leaf", None) == 0: - attr_pairs["base_values"] = [base_score] + attr_pairs["base_values"] = base_score_as_list(base_score) elif base_score != 0.5: # 0.5 -> cst = 0 - cst = -np.log(1 / np.float32(base_score) - 1.0) - attr_pairs["base_values"] = [cst] + cst = -np.log(1 / np.array(base_score_as_list(base_score), dtype=np.float32) - 1.0) + attr_pairs["base_values"] = cst.tolist() else: - attr_pairs["base_values"] = [base_score] + attr_pairs["base_values"] = base_score_as_list(base_score) else: # See https://github.com/dmlc/xgboost/blob/main/src/common/math.h#L35. attr_pairs["post_transform"] = "SOFTMAX" - attr_pairs["base_values"] = [base_score for n in range(ncl)] + if isinstance(base_score, list): + attr_pairs["base_values"] = base_score + else: + attr_pairs["base_values"] = [base_score for n in range(ncl)] attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]] classes = xgb_node.classes_