1010except ImportError :
1111 XGBRFClassifier = None
1212from ...common ._registration import register_converter
13- from ..common import get_xgb_params
13+ from ..common import get_xgb_params , get_n_estimators_classifier
1414
1515
1616class XGBConverter :
@@ -161,8 +161,7 @@ def _fill_node_attributes(
161161 false_child_id = remap [jsnode ["no" ]], # ['children'][1]['nodeid'],
162162 weights = None ,
163163 weight_id_bias = None ,
164- missing = jsnode .get ("missing" , - 1 )
165- == jsnode ["yes" ], # ['children'][0]['nodeid'],
164+ missing = jsnode .get ("missing" , - 1 ) == jsnode ["yes" ],
166165 hitrate = jsnode .get ("cover" , 0 ),
167166 )
168167
@@ -265,8 +264,8 @@ def convert(scope, operator, container):
265264 )
266265
267266 if objective == "count:poisson" :
268- cst = scope .get_unique_variable_name ("half " )
269- container .add_initializer (cst , TensorProto .FLOAT , [1 ], [0.5 ])
267+ cst = scope .get_unique_variable_name ("poisson " )
268+ container .add_initializer (cst , TensorProto .FLOAT , [1 ], [base_score ])
270269 new_name = scope .get_unique_variable_name ("exp" )
271270 container .add_node ("Exp" , names , [new_name ])
272271 container .add_node ("Mul" , [new_name , cst ], operator .output_full_names )
@@ -293,11 +292,18 @@ def convert(scope, operator, container):
293292 objective , base_score , js_trees = XGBConverter .common_members (xgb_node , inputs )
294293
295294 params = XGBConverter .get_xgb_params (xgb_node )
295+ n_estimators = get_n_estimators_classifier (xgb_node , params , js_trees )
296+ num_class = params .get ("num_class" , None )
297+
296298 attr_pairs = XGBClassifierConverter ._get_default_tree_attribute_pairs ()
297299 XGBConverter .fill_tree_attributes (
298300 js_trees , attr_pairs , [1 for _ in js_trees ], True
299301 )
300- ncl = (max (attr_pairs ["class_treeids" ]) + 1 ) // params ["n_estimators" ]
302+ if num_class is not None :
303+ ncl = num_class
304+ n_estimators = len (js_trees ) // ncl
305+ else :
306+ ncl = (max (attr_pairs ["class_treeids" ]) + 1 ) // n_estimators
301307
302308 bst = xgb_node .get_booster ()
303309 best_ntree_limit = getattr (bst , "best_ntree_limit" , len (js_trees )) * ncl
@@ -310,15 +316,17 @@ def convert(scope, operator, container):
310316
311317 if len (attr_pairs ["class_treeids" ]) == 0 :
312318 raise RuntimeError ("XGBoost model is empty." )
319+
313320 if ncl <= 1 :
314321 ncl = 2
315322 if objective != "binary:hinge" :
316323 # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
317324 attr_pairs ["post_transform" ] = "LOGISTIC"
318325 attr_pairs ["class_ids" ] = [0 for v in attr_pairs ["class_treeids" ]]
319326 if js_trees [0 ].get ("leaf" , None ) == 0 :
320- attr_pairs ["base_values" ] = [0.5 ]
327+ attr_pairs ["base_values" ] = [base_score ]
321328 elif base_score != 0.5 :
329+ # 0.5 -> cst = 0
322330 cst = - np .log (1 / np .float32 (base_score ) - 1.0 )
323331 attr_pairs ["base_values" ] = [cst ]
324332 else :
@@ -330,8 +338,10 @@ def convert(scope, operator, container):
330338 attr_pairs ["class_ids" ] = [v % ncl for v in attr_pairs ["class_treeids" ]]
331339
332340 classes = xgb_node .classes_
333- if np .issubdtype (classes .dtype , np .floating ) or np .issubdtype (
334- classes .dtype , np .integer
341+ if (
342+ np .issubdtype (classes .dtype , np .floating )
343+ or np .issubdtype (classes .dtype , np .integer )
344+ or np .issubdtype (classes .dtype , np .bool_ )
335345 ):
336346 attr_pairs ["classlabels_int64s" ] = classes .astype ("int" )
337347 else :
@@ -373,7 +383,7 @@ def convert(scope, operator, container):
373383 "Where" , [greater , one , zero ], operator .output_full_names [1 ]
374384 )
375385 elif objective in ("multi:softprob" , "multi:softmax" ):
376- ncl = len (js_trees ) // params [ " n_estimators" ]
386+ ncl = len (js_trees ) // n_estimators
377387 if objective == "multi:softmax" :
378388 attr_pairs ["post_transform" ] = "NONE"
379389 container .add_node (
@@ -385,7 +395,7 @@ def convert(scope, operator, container):
385395 ** attr_pairs ,
386396 )
387397 elif objective == "reg:logistic" :
388- ncl = len (js_trees ) // params [ " n_estimators" ]
398+ ncl = len (js_trees ) // n_estimators
389399 if ncl == 1 :
390400 ncl = 2
391401 container .add_node (
0 commit comments