Skip to content

Commit 4830c26

Browse files
authored
Merge pull request #1996 from ales-erjavec/fixes/test-learners-fix-one-vs-rest
[FIX] Test Learners: Fix AUC for selected single target class
2 parents 8ca36a3 + 4d5d38a commit 4830c26

File tree

2 files changed

+100
-3
lines changed

2 files changed

+100
-3
lines changed

Orange/widgets/evaluate/owtestlearners.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -709,11 +709,41 @@ def results_merge(results):
709709

710710

711711
def results_one_vs_rest(results, pos_index):
712+
from Orange.preprocess.transformation import Indicator
712713
actual = results.actual == pos_index
713714
predicted = results.predicted == pos_index
714-
return Orange.evaluation.Results(
715-
nmethods=1, domain=results.domain,
716-
actual=actual, predicted=predicted)
715+
if results.probabilities is not None:
716+
c = results.probabilities.shape[2]
717+
assert c >= 2
718+
neg_indices = [i for i in range(c) if i != pos_index]
719+
pos_prob = results.probabilities[:, :, [pos_index]]
720+
neg_prob = np.sum(results.probabilities[:, :, neg_indices],
721+
axis=2, keepdims=True)
722+
probabilities = np.dstack((neg_prob, pos_prob))
723+
else:
724+
probabilities = None
725+
726+
res = Orange.evaluation.Results()
727+
res.actual = actual
728+
res.predicted = predicted
729+
res.folds = results.folds
730+
res.row_indices = results.row_indices
731+
res.probabilities = probabilities
732+
733+
value = results.domain.class_var.values[pos_index]
734+
class_var = Orange.data.DiscreteVariable(
735+
"I({}=={})".format(results.domain.class_var.name, value),
736+
values=["False", "True"],
737+
compute_value=Indicator(results.domain.class_var, pos_index)
738+
)
739+
domain = Orange.data.Domain(
740+
results.domain.attributes,
741+
[class_var],
742+
results.domain.metas
743+
)
744+
res.data = None
745+
res.domain = domain
746+
return res
717747

718748

719749
def main(argv=None):
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# pylint: disable=missing-docstring
2+
import numpy as np
3+
4+
import unittest
5+
6+
from Orange.data import Table
7+
from Orange.classification import MajorityLearner
8+
from Orange.regression import MeanLearner
9+
10+
from Orange.evaluation import Results, TestOnTestData
11+
from Orange.widgets.tests.base import WidgetTest
12+
from Orange.widgets.evaluate.owtestlearners import OWTestLearners
13+
from Orange.widgets.evaluate import owtestlearners
14+
15+
16+
class TestOWTestLearners(WidgetTest):
17+
def setUp(self):
18+
super().setUp()
19+
self.widget = self.create_widget(OWTestLearners) # type: OWTestLearners
20+
21+
def test_basic(self):
22+
data = Table("iris")[::3]
23+
self.send_signal("Data", data)
24+
self.send_signal("Learner", MajorityLearner(), 0)
25+
res = self.get_output("Evaluation Results")
26+
self.assertIsInstance(res, Results)
27+
self.assertIsNotNone(res.domain)
28+
self.assertIsNotNone(res.data)
29+
self.assertIsNotNone(res.probabilities)
30+
31+
self.send_signal("Learner", None, 0)
32+
33+
data = Table("housing")[::10]
34+
self.send_signal("Data", data)
35+
self.send_signal("Learner", MeanLearner(), 0)
36+
res = self.get_output("Evaluation Results")
37+
self.assertIsInstance(res, Results)
38+
self.assertIsNotNone(res.domain)
39+
self.assertIsNotNone(res.data)
40+
41+
42+
class TestHelpers(unittest.TestCase):
43+
def test_results_one_vs_rest(self):
44+
data = Table("lenses")
45+
learners = [MajorityLearner()]
46+
res = TestOnTestData(data[1::2], data[::2], learners=learners)
47+
r1 = owtestlearners.results_one_vs_rest(res, pos_index=0)
48+
r2 = owtestlearners.results_one_vs_rest(res, pos_index=1)
49+
r3 = owtestlearners.results_one_vs_rest(res, pos_index=2)
50+
51+
np.testing.assert_almost_equal(np.sum(r1.probabilities, axis=2), 1.0)
52+
np.testing.assert_almost_equal(np.sum(r2.probabilities, axis=2), 1.0)
53+
np.testing.assert_almost_equal(np.sum(r3.probabilities, axis=2), 1.0)
54+
55+
np.testing.assert_almost_equal(
56+
r1.probabilities[:, :, 1] +
57+
r2.probabilities[:, :, 1] +
58+
r3.probabilities[:, :, 1],
59+
1.0
60+
)
61+
self.assertEqual(r1.folds, res.folds)
62+
self.assertEqual(r2.folds, res.folds)
63+
self.assertEqual(r3.folds, res.folds)
64+
65+
np.testing.assert_equal(r1.row_indices, res.row_indices)
66+
np.testing.assert_equal(r2.row_indices, res.row_indices)
67+
np.testing.assert_equal(r3.row_indices, res.row_indices)

0 commit comments

Comments
 (0)