Skip to content

Commit 029d3d4

Browse files
authored
Merge pull request #2940 from PrimozGodec/testl-unittests
Test and Score: Unittests that test the correctness of scores
2 parents 2d460c1 + 8fbc19b commit 029d3d4

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

Orange/widgets/evaluate/tests/test_owtestlearners.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def setUp(self):
2929
super().setUp()
3030
self.widget = self.create_widget(OWTestLearners) # type: OWTestLearners
3131

32+
self.scores_domain = Domain(
33+
[ContinuousVariable("a"), ContinuousVariable("b")],
34+
[DiscreteVariable("c", values=["y", "n"])])
35+
36+
self.scores_table_values = [[1, 1, 1.23, 23.8], [1., 2., 3., 4.]]
37+
3238
def tearDown(self):
3339
self.widget.onDeleteWidget()
3440
super().tearDown()
@@ -321,6 +327,100 @@ def __call__(self, data):
321327

322328
self.widget.hide()
323329

330+
def _retrieve_scores(self):
331+
w = self.widget
332+
auc = w.view.model().item(0, 1).text()
333+
auc = float(auc) if auc != "" else None
334+
ca = float(w.view.model().item(0, 2).text())
335+
f1 = float(w.view.model().item(0, 3).text())
336+
precision = float(w.view.model().item(0, 4).text())
337+
recall = float(w.view.model().item(0, 5).text())
338+
return auc, ca, f1, precision, recall
339+
340+
def _test_scores(self, train_data, test_data, learner, sampling, n_folds):
341+
w = self.widget #: OWTestLearners
342+
w.controls.resampling.buttons[sampling].click()
343+
if n_folds is not None:
344+
w.n_folds = n_folds
345+
346+
self.send_signal(self.widget.Inputs.train_data, train_data)
347+
if test_data is not None:
348+
self.send_signal(self.widget.Inputs.test_data, test_data)
349+
self.send_signal(self.widget.Inputs.learner, learner, 0, wait=5000)
350+
return self._retrieve_scores()
351+
352+
def test_scores_constant_all_same(self):
353+
table = Table(
354+
self.scores_domain,
355+
list(zip(*self.scores_table_values + [list("yyyy")]))
356+
)
357+
358+
self.assertTupleEqual(self._test_scores(
359+
table, table, ConstantLearner(), OWTestLearners.TestOnTest, None),
360+
(None, 1, 1, 1, 1))
361+
362+
def test_scores_log_reg_overfitted(self):
363+
table = Table(
364+
self.scores_domain,
365+
list(zip(*self.scores_table_values + [list("yyyn")]))
366+
)
367+
368+
self.assertTupleEqual(self._test_scores(
369+
table, table, LogisticRegressionLearner(),
370+
OWTestLearners.TestOnTest, None),
371+
(1, 1, 1, 1, 1))
372+
373+
def test_scores_log_reg_bad(self):
374+
table_train = Table(
375+
self.scores_domain,
376+
list(zip(*self.scores_table_values + [list("nnny")]))
377+
)
378+
table_test = Table(
379+
self.scores_domain,
380+
list(zip(*self.scores_table_values + [list("yyyn")]))
381+
)
382+
383+
self.assertTupleEqual(self._test_scores(
384+
table_train, table_test, LogisticRegressionLearner(),
385+
OWTestLearners.TestOnTest, None),
386+
(0, 0, 0, 0, 0))
387+
388+
def test_scores_log_reg_bad2(self):
389+
table_train = Table(
390+
self.scores_domain,
391+
list(zip(*(self.scores_table_values + [list("nnyy")]))))
392+
table_test = Table(
393+
self.scores_domain,
394+
list(zip(*(self.scores_table_values + [list("yynn")]))))
395+
self.assertTupleEqual(self._test_scores(
396+
table_train, table_test, LogisticRegressionLearner(),
397+
OWTestLearners.TestOnTest, None),
398+
(0, 0, 0, 0, 0))
399+
400+
def test_scores_log_reg_advanced(self):
401+
table_train = Table(
402+
self.scores_domain, list(zip(
403+
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yyynn"))
404+
)
405+
table_test = Table(
406+
self.scores_domain, list(zip(
407+
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
408+
)
409+
410+
self.assertTupleEqual(self._test_scores(
411+
table_train, table_test, LogisticRegressionLearner(),
412+
OWTestLearners.TestOnTest, None),
413+
(0.667, 0.8, 0.8, 0.867, 0.8))
414+
415+
def test_scores_cross_validation(self):
416+
"""
417+
Test more than two classes and cross-validation
418+
"""
419+
self.assertTupleEqual(self._test_scores(
420+
Table("iris")[::15], None, LogisticRegressionLearner(),
421+
OWTestLearners.KFold, 0),
422+
(0.917, 0.7, 0.6, 0.55, 0.7))
423+
324424

325425
class TestHelpers(unittest.TestCase):
326426
def test_results_one_vs_rest(self):

0 commit comments

Comments
 (0)