22
33import json
44import numpy as np
5+ from onnx import TensorProto
56from xgboost import XGBClassifier
67from ...common ._registration import register_converter
78from ..common import get_xgb_params
@@ -241,14 +242,17 @@ def convert(scope, operator, container):
241242 raise RuntimeError ("XGBoost model is empty." )
242243 if ncl <= 1 :
243244 ncl = 2
244- # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
245- attr_pairs ['post_transform' ] = "LOGISTIC"
246- attr_pairs ['class_ids' ] = [0 for v in attr_pairs ['class_treeids' ]]
247- if js_trees [0 ].get ('leaf' , None ) == 0 :
248- attr_pairs ['base_values' ] = [0.5 ]
249- elif base_score != 0.5 :
250- cst = - np .log (1 / np .float32 (base_score ) - 1. )
251- attr_pairs ['base_values' ] = [cst ]
245+ if objective != 'binary:hinge' :
246+ # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
247+ attr_pairs ['post_transform' ] = "LOGISTIC"
248+ attr_pairs ['class_ids' ] = [0 for v in attr_pairs ['class_treeids' ]]
249+ if js_trees [0 ].get ('leaf' , None ) == 0 :
250+ attr_pairs ['base_values' ] = [0.5 ]
251+ elif base_score != 0.5 :
252+ cst = - np .log (1 / np .float32 (base_score ) - 1. )
253+ attr_pairs ['base_values' ] = [cst ]
254+ else :
255+ attr_pairs ['base_values' ] = [base_score ]
252256 else :
253257 # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
254258 attr_pairs ['post_transform' ] = "SOFTMAX"
@@ -264,13 +268,33 @@ def convert(scope, operator, container):
264268 attr_pairs ['classlabels_strings' ] = classes
265269
266270 # add nodes
267- if objective == "binary:logistic" :
271+ if objective in ( "binary:logistic" , "binary:hinge" ) :
268272 ncl = 2
269- container .add_node ('TreeEnsembleClassifier' , operator .input_full_names ,
270- operator .output_full_names ,
273+ if objective == "binary:hinge" :
274+ attr_pairs ['post_transform' ] = 'NONE'
275+ output_names = [operator .output_full_names [0 ],
276+ scope .get_unique_variable_name ("output_prob" )]
277+ else :
278+ output_names = operator .output_full_names
279+ container .add_node ('TreeEnsembleClassifier' ,
280+ operator .input_full_names ,
281+ output_names ,
271282 op_domain = 'ai.onnx.ml' ,
272283 name = scope .get_unique_operator_name ('TreeEnsembleClassifier' ),
273284 ** attr_pairs )
285+ if objective == "binary:hinge" :
286+ if container .target_opset < 9 :
287+ raise RuntimeError (
288+ f"hinge function cannot be implemented because "
289+ f"opset={ container .target_opset } <9." )
290+ zero = scope .get_unique_variable_name ("zero" )
291+ one = scope .get_unique_variable_name ("one" )
292+ container .add_initializer (zero , TensorProto .FLOAT , [1 ], [0. ])
293+ container .add_initializer (one , TensorProto .FLOAT , [1 ], [1. ])
294+ greater = scope .get_unique_variable_name ("output_prob" )
295+ container .add_node ("Greater" , [output_names [1 ], zero ], [greater ])
296+ container .add_node ('Where' , [greater , one , zero ],
297+ operator .output_full_names [1 ])
274298 elif objective in ("multi:softprob" , "multi:softmax" ):
275299 ncl = len (js_trees ) // params ['n_estimators' ]
276300 if objective == 'multi:softmax' :
0 commit comments