@@ -29,6 +29,12 @@ def setUp(self):
2929 super ().setUp ()
3030 self .widget = self .create_widget (OWTestLearners ) # type: OWTestLearners
3131
32+ self .scores_domain = Domain (
33+ [ContinuousVariable ("a" ), ContinuousVariable ("b" )],
34+ [DiscreteVariable ("c" , values = ["y" , "n" ])])
35+
36+ self .scores_table_values = [[1 , 1 , 1.23 , 23.8 ], [1. , 2. , 3. , 4. ]]
37+
3238 def tearDown (self ):
3339 self .widget .onDeleteWidget ()
3440 super ().tearDown ()
@@ -321,6 +327,100 @@ def __call__(self, data):
321327
322328 self .widget .hide ()
323329
330+ def _retrieve_scores (self ):
331+ w = self .widget
332+ auc = w .view .model ().item (0 , 1 ).text ()
333+ auc = float (auc ) if auc != "" else None
334+ ca = float (w .view .model ().item (0 , 2 ).text ())
335+ f1 = float (w .view .model ().item (0 , 3 ).text ())
336+ precision = float (w .view .model ().item (0 , 4 ).text ())
337+ recall = float (w .view .model ().item (0 , 5 ).text ())
338+ return auc , ca , f1 , precision , recall
339+
340+ def _test_scores (self , train_data , test_data , learner , sampling , n_folds ):
341+ w = self .widget #: OWTestLearners
342+ w .controls .resampling .buttons [sampling ].click ()
343+ if n_folds is not None :
344+ w .n_folds = n_folds
345+
346+ self .send_signal (self .widget .Inputs .train_data , train_data )
347+ if test_data is not None :
348+ self .send_signal (self .widget .Inputs .test_data , test_data )
349+ self .send_signal (self .widget .Inputs .learner , learner , 0 , wait = 5000 )
350+ return self ._retrieve_scores ()
351+
352+ def test_scores_constant_all_same (self ):
353+ table = Table (
354+ self .scores_domain ,
355+ list (zip (* self .scores_table_values + [list ("yyyy" )]))
356+ )
357+
358+ self .assertTupleEqual (self ._test_scores (
359+ table , table , ConstantLearner (), OWTestLearners .TestOnTest , None ),
360+ (None , 1 , 1 , 1 , 1 ))
361+
362+ def test_scores_log_reg_overfitted (self ):
363+ table = Table (
364+ self .scores_domain ,
365+ list (zip (* self .scores_table_values + [list ("yyyn" )]))
366+ )
367+
368+ self .assertTupleEqual (self ._test_scores (
369+ table , table , LogisticRegressionLearner (),
370+ OWTestLearners .TestOnTest , None ),
371+ (1 , 1 , 1 , 1 , 1 ))
372+
373+ def test_scores_log_reg_bad (self ):
374+ table_train = Table (
375+ self .scores_domain ,
376+ list (zip (* self .scores_table_values + [list ("nnny" )]))
377+ )
378+ table_test = Table (
379+ self .scores_domain ,
380+ list (zip (* self .scores_table_values + [list ("yyyn" )]))
381+ )
382+
383+ self .assertTupleEqual (self ._test_scores (
384+ table_train , table_test , LogisticRegressionLearner (),
385+ OWTestLearners .TestOnTest , None ),
386+ (0 , 0 , 0 , 0 , 0 ))
387+
388+ def test_scores_log_reg_bad2 (self ):
389+ table_train = Table (
390+ self .scores_domain ,
391+ list (zip (* (self .scores_table_values + [list ("nnyy" )]))))
392+ table_test = Table (
393+ self .scores_domain ,
394+ list (zip (* (self .scores_table_values + [list ("yynn" )]))))
395+ self .assertTupleEqual (self ._test_scores (
396+ table_train , table_test , LogisticRegressionLearner (),
397+ OWTestLearners .TestOnTest , None ),
398+ (0 , 0 , 0 , 0 , 0 ))
399+
400+ def test_scores_log_reg_advanced (self ):
401+ table_train = Table (
402+ self .scores_domain , list (zip (
403+ [1 , 1 , 1.23 , 23.8 , 5. ], [1. , 2. , 3. , 4. , 3. ], "yyynn" ))
404+ )
405+ table_test = Table (
406+ self .scores_domain , list (zip (
407+ [1 , 1 , 1.23 , 23.8 , 5. ], [1. , 2. , 3. , 4. , 3. ], "yynnn" ))
408+ )
409+
410+ self .assertTupleEqual (self ._test_scores (
411+ table_train , table_test , LogisticRegressionLearner (),
412+ OWTestLearners .TestOnTest , None ),
413+ (0.667 , 0.8 , 0.8 , 0.867 , 0.8 ))
414+
415+ def test_scores_cross_validation (self ):
416+ """
417+ Test more than two classes and cross-validation
418+ """
419+ self .assertTupleEqual (self ._test_scores (
420+ Table ("iris" )[::15 ], None , LogisticRegressionLearner (),
421+ OWTestLearners .KFold , 0 ),
422+ (0.917 , 0.7 , 0.6 , 0.55 , 0.7 ))
423+
324424
325425class TestHelpers (unittest .TestCase ):
326426 def test_results_one_vs_rest (self ):
0 commit comments