11import json
22import unittest
3- import sys
43from typing import Type
5- from unittest .mock import patch , Mock
4+ from unittest .mock import Mock
65
76from Orange .classification import GBClassifier
87
9- try :
10- from Orange .classification import XGBClassifier , XGBRFClassifier
11- except ImportError :
12- XGBClassifier = XGBRFClassifier = None
13- try :
14- from Orange .classification import CatGBClassifier
15- except ImportError :
16- CatGBClassifier = None
8+ from Orange .classification import XGBClassifier , XGBRFClassifier
9+ from Orange .classification import CatGBClassifier
1710from Orange .data import Table
18- from Orange .modelling import GBLearner
1911from Orange .preprocess .score import Scorer
2012from Orange .regression import GBRegressor
2113
22- try :
23- from Orange .regression import XGBRegressor , XGBRFRegressor
24- except ImportError :
25- XGBRegressor = XGBRFRegressor = None
26- try :
27- from Orange .regression import CatGBRegressor
28- except ImportError :
29- CatGBRegressor = None
14+ from Orange .regression import XGBRegressor , XGBRFRegressor
15+ from Orange .regression import CatGBRegressor
3016from Orange .widgets .model .owgradientboosting import OWGradientBoosting , \
3117 LearnerItemModel , GBLearnerEditor , XGBLearnerEditor , XGBRFLearnerEditor , \
3218 CatGBLearnerEditor , BaseEditor
@@ -65,16 +51,6 @@ def test_model(self):
6551 self .assertEqual (model .item (i ).isEnabled (),
6652 classifiers [i ] is not None )
6753
68- @patch ("Orange.widgets.model.owgradientboosting.LearnerItemModel.LEARNERS" ,
69- [(GBLearner , "" , "" ),
70- (None , "Gradient Boosting (catboost)" , "catboost" )])
71- def test_missing_lib (self ):
72- widget = create_parent (CatGBLearnerEditor )
73- model = LearnerItemModel (widget )
74- self .assertEqual (model .rowCount (), 2 )
75- self .assertTrue (model .item (0 ).isEnabled ())
76- self .assertFalse (model .item (1 ).isEnabled ())
77-
7854
7955class BaseEditorTest (GuiTest ):
8056 EditorClass : Type [BaseEditor ] = None
@@ -146,7 +122,6 @@ def test_arguments(self):
146122 "colsample_bynode" : 1 , "subsample" : 1 , "random_state" : 0 }
147123 self .assertDictEqual (self .editor .get_arguments (), args )
148124
149- @unittest .skipIf (XGBClassifier is None , "Missing 'xgboost' package" )
150125 def test_learner_parameters (self ):
151126 params = (("Method" , "Extreme Gradient Boosting (xgboost)" ),
152127 ("Number of trees" , 100 ),
@@ -160,7 +135,6 @@ def test_learner_parameters(self):
160135 ("Fraction of features for each split" , 1 ))
161136 self .assertTupleEqual (self .editor .get_learner_parameters (), params )
162137
163- @unittest .skipIf (XGBClassifier is None , "Missing 'xgboost' package" )
164138 def test_default_parameters_cls (self ):
165139 data = Table ("heart_disease" )
166140 booster = XGBClassifier ()
@@ -178,7 +152,6 @@ def test_default_parameters_cls(self):
178152 self .assertEqual (int (tp ["colsample_bylevel" ]), self .editor .colsample_bylevel )
179153 self .assertEqual (int (tp ["colsample_bynode" ]), self .editor .colsample_bynode )
180154
181- @unittest .skipIf (XGBRegressor is None , "Missing 'xgboost' package" )
182155 def test_default_parameters_reg (self ):
183156 data = Table ("housing" )
184157 booster = XGBRegressor ()
@@ -206,7 +179,6 @@ def test_arguments(self):
206179 "colsample_bynode" : 1 , "subsample" : 1 , "random_state" : 0 }
207180 self .assertDictEqual (self .editor .get_arguments (), args )
208181
209- @unittest .skipIf (XGBRFClassifier is None , "Missing 'xgboost' package" )
210182 def test_learner_parameters (self ):
211183 params = (("Method" ,
212184 "Extreme Gradient Boosting Random Forest (xgboost)" ),
@@ -221,7 +193,6 @@ def test_learner_parameters(self):
221193 ("Fraction of features for each split" , 1 ))
222194 self .assertTupleEqual (self .editor .get_learner_parameters (), params )
223195
224- @unittest .skipIf (XGBRFClassifier is None , "Missing 'xgboost' package" )
225196 def test_default_parameters_cls (self ):
226197 data = Table ("heart_disease" )
227198 booster = XGBRFClassifier ()
@@ -239,7 +210,6 @@ def test_default_parameters_cls(self):
239210 self .assertEqual (int (tp ["colsample_bylevel" ]), self .editor .colsample_bylevel )
240211 self .assertEqual (int (tp ["colsample_bynode" ]), self .editor .colsample_bynode )
241212
242- @unittest .skipIf (XGBRFRegressor is None , "Missing 'xgboost' package" )
243213 def test_default_parameters_reg (self ):
244214 data = Table ("housing" )
245215 booster = XGBRFRegressor ()
@@ -266,7 +236,6 @@ def test_arguments(self):
266236 "reg_lambda" : 3 , "colsample_bylevel" : 1 , "random_state" : 0 }
267237 self .assertDictEqual (self .editor .get_arguments (), args )
268238
269- @unittest .skipIf (CatGBClassifier is None , "Missing 'catboost' package" )
270239 def test_learner_parameters (self ):
271240 params = (("Method" , "Gradient Boosting (catboost)" ),
272241 ("Number of trees" , 100 ),
@@ -277,7 +246,6 @@ def test_learner_parameters(self):
277246 ("Fraction of features for each tree" , 1 ))
278247 self .assertTupleEqual (self .editor .get_learner_parameters (), params )
279248
280- @unittest .skipIf (CatGBClassifier is None , "Missing 'catboost' package" )
281249 def test_default_parameters_cls (self ):
282250 data = Table ("heart_disease" )
283251 booster = CatGBClassifier ()
@@ -291,7 +259,6 @@ def test_default_parameters_cls(self):
291259 self .assertEqual (self .editor .learning_rate , 0.3 )
292260 # params["learning_rate"] is automatically defined so don't test it
293261
294- @unittest .skipIf (CatGBRegressor is None , "Missing 'catboost' package" )
295262 def test_default_parameters_reg (self ):
296263 data = Table ("housing" )
297264 booster = CatGBRegressor ()
@@ -305,6 +272,7 @@ def test_default_parameters_reg(self):
305272 self .assertEqual (self .editor .learning_rate , 0.3 )
306273 # params["learning_rate"] is automatically defined so don't test it
307274
275+
308276class TestOWGradientBoosting (WidgetTest , WidgetLearnerTestMixin ):
309277 def setUp (self ):
310278 self .widget = self .create_widget (OWGradientBoosting ,
@@ -328,7 +296,6 @@ def test_datasets(self):
328296 for ds in datasets .datasets ():
329297 self .send_signal (self .widget .Inputs .data , ds )
330298
331- @unittest .skipIf (XGBClassifier is None , "Missing 'xgboost' package" )
332299 def test_xgb_params (self ):
333300 simulate .combobox_activate_index (self .widget .controls .method_index , 1 )
334301 editor = self .widget .editor
@@ -350,27 +317,11 @@ def test_xgb_params(self):
350317 def test_methods (self ):
351318 self .send_signal (self .widget .Inputs .data , self .data )
352319 method_cb = self .widget .controls .method_index
353- for i , (cls , _ , _ ) in enumerate (LearnerItemModel .LEARNERS ):
354- if cls is None :
355- continue
320+ for i , cls in enumerate (LearnerItemModel .LEARNERS ):
356321 simulate .combobox_activate_index (method_cb , i )
357322 self .click_apply ()
358323 self .assertIsInstance (self .widget .learner , cls )
359324
360- def test_missing_lib (self ):
361- modules = {k : v for k , v in sys .modules .items ()
362- if "orange" not in k .lower ()} # retain built-ins
363- modules ["xgboost" ] = None
364- modules ["catboost" ] = None
365- # pylint: disable=reimported,redefined-outer-name
366- # pylint: disable=import-outside-toplevel
367- with patch .dict (sys .modules , modules , clear = True ):
368- from Orange .widgets .model .owgradientboosting import \
369- OWGradientBoosting
370- widget = self .create_widget (OWGradientBoosting ,
371- stored_settings = {"method_index" : 3 })
372- self .assertEqual (widget .method_index , 0 )
373-
374325
375326if __name__ == "__main__" :
376327 unittest .main ()
0 commit comments