3636from Orange .widgets .widget import OWWidget
3737
3838
39+ def get_tree_train_params (model ):
40+ ln = json .loads (model .skl_model .get_booster ().save_config ())["learner" ]
41+ try :
42+ return ln ["gradient_booster" ]["tree_train_param" ]
43+ except KeyError :
44+ return ln ["gradient_booster" ]["updater" ]["grow_colmaker" ]["train_param" ]
45+
46+
3947def create_parent (editor_class ):
4048 class DummyWidget (OWWidget ):
4149 name = "Mock"
@@ -158,9 +166,7 @@ def test_default_parameters_cls(self):
158166 booster = XGBClassifier ()
159167 model = booster (data )
160168 params = model .skl_model .get_params ()
161- booster_params = json .loads (model .skl_model .get_booster ().save_config ())
162- updater = booster_params ["learner" ]["gradient_booster" ]["updater" ]
163- tp = updater ["grow_colmaker" ]["train_param" ]
169+ tp = get_tree_train_params (model )
164170 self .assertEqual (params ["n_estimators" ], self .editor .n_estimators )
165171 self .assertEqual (
166172 round (float (tp ["learning_rate" ]), 1 ), self .editor .learning_rate
@@ -178,9 +184,7 @@ def test_default_parameters_reg(self):
178184 booster = XGBRegressor ()
179185 model = booster (data )
180186 params = model .skl_model .get_params ()
181- booster_params = json .loads (model .skl_model .get_booster ().save_config ())
182- updater = booster_params ["learner" ]["gradient_booster" ]["updater" ]
183- tp = updater ["grow_colmaker" ]["train_param" ]
187+ tp = get_tree_train_params (model )
184188 self .assertEqual (params ["n_estimators" ], self .editor .n_estimators )
185189 self .assertEqual (
186190 round (float (tp ["learning_rate" ]), 1 ), self .editor .learning_rate
@@ -223,9 +227,7 @@ def test_default_parameters_cls(self):
223227 booster = XGBRFClassifier ()
224228 model = booster (data )
225229 params = model .skl_model .get_params ()
226- booster_params = json .loads (model .skl_model .get_booster ().save_config ())
227- updater = booster_params ["learner" ]["gradient_booster" ]["updater" ]
228- tp = updater ["grow_colmaker" ]["train_param" ]
230+ tp = get_tree_train_params (model )
229231 self .assertEqual (params ["n_estimators" ], self .editor .n_estimators )
230232 self .assertEqual (
231233 round (float (tp ["learning_rate" ]), 1 ), self .editor .learning_rate
@@ -243,9 +245,7 @@ def test_default_parameters_reg(self):
243245 booster = XGBRFRegressor ()
244246 model = booster (data )
245247 params = model .skl_model .get_params ()
246- booster_params = json .loads (model .skl_model .get_booster ().save_config ())
247- updater = booster_params ["learner" ]["gradient_booster" ]["updater" ]
248- tp = updater ["grow_colmaker" ]["train_param" ]
248+ tp = get_tree_train_params (model )
249249 self .assertEqual (params ["n_estimators" ], self .editor .n_estimators )
250250 self .assertEqual (
251251 round (float (tp ["learning_rate" ]), 1 ), self .editor .learning_rate
0 commit comments