@@ -40,14 +40,20 @@ def common_members(xgb_node, inputs):
4040 params = XGBConverter .get_xgb_params (xgb_node )
4141 objective = params ["objective" ]
4242 base_score = params ["base_score" ]
43+ if hasattr (xgb_node , "best_ntree_limit" ):
44+ best_ntree_limit = xgb_node .best_ntree_limit
45+ elif hasattr (xgb_node , "best_iteration" ):
46+ best_ntree_limit = xgb_node .best_iteration + 1
47+ else :
48+ best_ntree_limit = params .get ("best_ntree_limit" , None )
4349 if base_score is None :
4450 base_score = 0.5
4551 booster = xgb_node .get_booster ()
4652 # The json format was available in October 2017.
4753 # XGBoost 0.7 was the first version released with it.
4854 js_tree_list = booster .get_dump (with_stats = True , dump_format = "json" )
4955 js_trees = [json .loads (s ) for s in js_tree_list ]
50- return objective , base_score , js_trees
56+ return objective , base_score , js_trees , best_ntree_limit
5157
5258 @staticmethod
5359 def _get_default_tree_attribute_pairs (is_classifier ):
@@ -231,17 +237,17 @@ def _get_default_tree_attribute_pairs():
231237 def convert (scope , operator , container ):
232238 xgb_node = operator .raw_operator
233239 inputs = operator .inputs
234- objective , base_score , js_trees = XGBConverter .common_members (xgb_node , inputs )
240+ objective , base_score , js_trees , best_ntree_limit = XGBConverter .common_members (
241+ xgb_node , inputs
242+ )
235243
236244 if objective in ["reg:gamma" , "reg:tweedie" ]:
237245 raise RuntimeError ("Objective '{}' not supported." .format (objective ))
238246
239247 attr_pairs = XGBRegressorConverter ._get_default_tree_attribute_pairs ()
240248 attr_pairs ["base_values" ] = [base_score ]
241249
242- bst = xgb_node .get_booster ()
243- best_ntree_limit = getattr (bst , "best_ntree_limit" , len (js_trees ))
244- if best_ntree_limit < len (js_trees ):
250+ if best_ntree_limit and best_ntree_limit < len (js_trees ):
245251 js_trees = js_trees [:best_ntree_limit ]
246252
247253 XGBConverter .fill_tree_attributes (
@@ -289,7 +295,9 @@ def convert(scope, operator, container):
289295 xgb_node = operator .raw_operator
290296 inputs = operator .inputs
291297
292- objective , base_score , js_trees = XGBConverter .common_members (xgb_node , inputs )
298+ objective , base_score , js_trees , best_ntree_limit = XGBConverter .common_members (
299+ xgb_node , inputs
300+ )
293301
294302 params = XGBConverter .get_xgb_params (xgb_node )
295303 n_estimators = get_n_estimators_classifier (xgb_node , params , js_trees )
@@ -305,8 +313,9 @@ def convert(scope, operator, container):
305313 else :
306314 ncl = (max (attr_pairs ["class_treeids" ]) + 1 ) // n_estimators
307315
308- bst = xgb_node .get_booster ()
309- best_ntree_limit = getattr (bst , "best_ntree_limit" , len (js_trees )) * ncl
316+ best_ntree_limit = best_ntree_limit or len (js_trees )
317+ if ncl > 0 :
318+ best_ntree_limit *= ncl
310319 if 0 < best_ntree_limit < len (js_trees ):
311320 js_trees = js_trees [:best_ntree_limit ]
312321 attr_pairs = XGBClassifierConverter ._get_default_tree_attribute_pairs ()
0 commit comments