2222 MajorityLearner ,
2323 RandomForestLearner , SimpleTreeLearner , SoftmaxRegressionLearner ,
2424 SVMLearner , LinearSVMLearner , OneClassSVMLearner , TreeLearner , KNNLearner ,
25- SimpleRandomForestLearner , EllipticEnvelopeLearner ,
26- SGDClassificationLearner )
25+ SimpleRandomForestLearner , EllipticEnvelopeLearner )
2726from Orange .classification .rules import _RuleLearner
2827from Orange .data import (ContinuousVariable , DiscreteVariable ,
2928 Domain , Table )
3332from Orange .tests import test_filename
3433
3534
35+ def all_learners ():
36+ classification_modules = pkgutil .walk_packages (
37+ path = Orange .classification .__path__ ,
38+ prefix = "Orange.classification." ,
39+ onerror = lambda x : None )
40+ for importer , modname , ispkg in classification_modules :
41+ try :
42+ module = pkgutil .importlib .import_module (modname )
43+ except ImportError :
44+ continue
45+
46+ for name , class_ in inspect .getmembers (module , inspect .isclass ):
47+ if (issubclass (class_ , Learner ) and
48+ not name .startswith ('_' ) and
49+ 'base' not in class_ .__module__ ):
50+ yield class_
51+
52+
3653class MultiClassTest (unittest .TestCase ):
3754 def test_unsupported (self ):
3855 nrows = 20
@@ -194,21 +211,70 @@ def test_incompatible_domain(self):
194211
195212 def test_result_shape (self ):
196213 """
197- This test function will be extended for all models in on of the
198- following pull requests.
214+ Test if the results shapes are correct
199215 """
200216 iris = Table ('iris' )
201- learner = SGDClassificationLearner ()
217+ for learner in all_learners ():
218+ with self .subTest (learner .__name__ ):
219+ # model trained on only one value (but three in the domain)
220+ try :
221+ model = learner ()(iris [0 :50 ])
222+ except TypeError :
223+ # cannot be tested with default parameters
224+ continue
225+
226+ res = model (iris [0 :50 ])
227+ self .assertTupleEqual ((50 ,), res .shape )
228+
229+ # probabilities must still be for three classes
230+ res = model (iris [0 :50 ], model .Probs )
231+ self .assertTupleEqual ((50 , 3 ), res .shape )
202232
203- # model trained on only one value (but three in the domain)
204- model = learner (iris )
233+ # model trained on all classes and predicting with one class
234+ try :
235+ model = learner ()(iris [0 :50 ])
236+ except TypeError :
237+ # cannot be tested with default parameters
238+ continue
239+ res = model (iris [0 :50 ], model .Probs )
240+ self .assertTupleEqual ((50 , 3 ), res .shape )
205241
206- res = model (iris [0 :50 ])
207- self .assertTupleEqual ((50 ,), res .shape )
242+ def test_result_shape_numpy (self ):
243+ """
244+ Test whether results shapes are correct when testing on numpy data
245+ """
246+ iris = Table ('iris' )
247+ for learner in all_learners ():
248+ with self .subTest (learner .__name__ ):
249+ if learner .__name__ == "CN2SDLearner" :
250+ # TODO: fix CN2SDLearner
251+ continue
252+ try :
253+ model = learner ()(iris )
254+ except TypeError :
255+ # cannot be tested with default parameters
256+ continue
257+ transformed_iris = model .data_to_model_domain (iris )
208258
209- # probabilities must still be for three classes
210- res = model (iris [0 :50 ], model .Probs )
211- self .assertTupleEqual ((50 , 3 ), res .shape )
259+ res = model (transformed_iris .X [0 :5 ])
260+ self .assertTupleEqual ((5 ,), res .shape )
261+
262+ res = model (transformed_iris .X [0 :1 ], model .Probs )
263+ self .assertTupleEqual ((1 , 3 ), res .shape )
264+
265+ def test_fit_one_class (self ):
266+ """
267+ Test whether the fitting with one class only pass - before it failed
268+ for some models.
269+ """
270+ iris = Table ('iris' )
271+ for learner in all_learners ():
272+ with self .subTest (learner .__name__ ):
273+ try :
274+ model = learner ()(iris [:50 ])
275+ except TypeError :
276+ # cannot be tested with default parameters
277+ continue
212278
213279
214280class ExpandProbabilitiesTest (unittest .TestCase ):
@@ -309,7 +375,7 @@ def test_unknown(self):
309375
310376 def test_missing_class (self ):
311377 table = Table (test_filename ("datasets/adult_sample_missing" ))
312- for learner in LearnerAccessibility (). all_learners ():
378+ for learner in all_learners ():
313379 try :
314380 learner = learner ()
315381 if isinstance (learner , NuSVMLearner ):
@@ -330,33 +396,15 @@ def setUp(self):
330396 # Convergence warnings are irrelevant for these tests
331397 warnings .filterwarnings ("ignore" , ".*" , ConvergenceWarning )
332398
333-
334- def all_learners (self ):
335- classification_modules = pkgutil .walk_packages (
336- path = Orange .classification .__path__ ,
337- prefix = "Orange.classification." ,
338- onerror = lambda x : None )
339- for importer , modname , ispkg in classification_modules :
340- try :
341- module = pkgutil .importlib .import_module (modname )
342- except ImportError :
343- continue
344-
345- for name , class_ in inspect .getmembers (module , inspect .isclass ):
346- if (issubclass (class_ , Learner ) and
347- not name .startswith ('_' ) and
348- 'base' not in class_ .__module__ ):
349- yield class_
350-
351399 def test_all_learners_accessible_in_Orange_classification_namespace (self ):
352- for learner in self . all_learners ():
400+ for learner in all_learners ():
353401 if not hasattr (Orange .classification , learner .__name__ ):
354402 self .fail ("%s is not visible in Orange.classification"
355403 " namespace" % learner .__name__ )
356404
357405 def test_all_models_work_after_unpickling (self ):
358406 datasets = [Table ('iris' ), Table ('titanic' )]
359- for learner in list (self . all_learners ()):
407+ for learner in list (all_learners ()):
360408 try :
361409 learner = learner ()
362410 except Exception :
@@ -381,7 +429,7 @@ def test_all_models_work_after_unpickling(self):
381429 % (learner .__class__ .__name__ , ds .name ))
382430
383431 def test_adequacy_all_learners (self ):
384- for learner in self . all_learners ():
432+ for learner in all_learners ():
385433 try :
386434 learner = learner ()
387435 table = Table ("housing" )
@@ -391,7 +439,7 @@ def test_adequacy_all_learners(self):
391439 continue
392440
393441 def test_adequacy_all_learners_multiclass (self ):
394- for learner in self . all_learners ():
442+ for learner in all_learners ():
395443 try :
396444 learner = learner ()
397445 table = Table (test_filename ("datasets/test8.tab" ))
0 commit comments