Skip to content

Commit 92bba3f

Browse files
authored
Merge pull request #1720 from janezd/fix-confusion-indices
[FIX] Confusion matrix: Map annotated data through row_indices, add probabi…
2 parents d8c42b0 + 820e66f commit 92bba3f

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

Orange/widgets/evaluate/owconfusionmatrix.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def set_results(self, results):
229229

230230
data = None
231231
if results is not None and results.data is not None:
232-
data = results.data
232+
data = results.data[results.row_indices]
233233

234234
if data is not None and not data.domain.has_discrete_class:
235235
self.Error.no_regression()
@@ -375,9 +375,8 @@ def commit(self):
375375
data.name = learner_name
376376

377377
if selected:
378-
row_indices = self.results.row_indices[selected]
379-
annotated_data = create_annotated_table(data, row_indices)
380-
data = data[row_indices]
378+
annotated_data = create_annotated_table(data, selected)
379+
data = data[selected]
381380
else:
382381
annotated_data = create_annotated_table(data, [])
383382
data = None

Orange/widgets/evaluate/tests/test_owconfusionmatrix.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# pylint: disable=missing-docstring
2+
import numpy as np
23

34
from Orange.data import Table
45
from Orange.classification import NaiveBayesLearner, TreeLearner
56
from Orange.regression import MeanLearner
6-
from Orange.evaluation.testing import CrossValidation, TestOnTrainingData
7+
from Orange.evaluation.testing import CrossValidation, TestOnTrainingData, \
8+
ShuffleSplit
79
from Orange.widgets.evaluate.owconfusionmatrix import OWConfusionMatrix
810
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin
911

@@ -16,11 +18,11 @@ def setUpClass(cls):
1618

1719
bayes = NaiveBayesLearner()
1820
tree = TreeLearner()
19-
iris = cls.data
21+
cls.iris = cls.data
2022
titanic = Table("titanic")
2123
common = dict(k=3, store_data=True)
22-
cls.results_1_iris = CrossValidation(iris, [bayes], **common)
23-
cls.results_2_iris = CrossValidation(iris, [bayes, tree], **common)
24+
cls.results_1_iris = CrossValidation(cls.iris, [bayes], **common)
25+
cls.results_2_iris = CrossValidation(cls.iris, [bayes, tree], **common)
2426
cls.results_2_titanic = CrossValidation(titanic, [bayes, tree],
2527
**common)
2628

@@ -59,8 +61,7 @@ def _select_data(self):
5961
def test_show_error_on_regression(self):
6062
"""On regression data, the widget must show error"""
6163
housing = Table("housing")
62-
results = TestOnTrainingData(housing, [MeanLearner()])
63-
results.data = housing
64+
results = TestOnTrainingData(housing, [MeanLearner()], store_data=True)
6465
self.send_signal("Evaluation Results", results)
6566
self.assertTrue(self.widget.Error.no_regression.is_shown())
6667
self.send_signal("Evaluation Results", None)
@@ -69,3 +70,15 @@ def test_show_error_on_regression(self):
6970
self.assertTrue(self.widget.Error.no_regression.is_shown())
7071
self.send_signal("Evaluation Results", self.results_1_iris)
7172
self.assertFalse(self.widget.Error.no_regression.is_shown())
73+
74+
def test_row_indices(self):
75+
"""Map data instances when using random shuffling"""
76+
results = ShuffleSplit(self.iris, [NaiveBayesLearner()],
77+
store_data=True)
78+
self.send_signal("Evaluation Results", results)
79+
self.widget.select_correct()
80+
selected = self.get_output("Selected Data")
81+
correct = np.equal(results.actual, results.predicted)[0]
82+
correct_indices = results.row_indices[correct]
83+
self.assertSetEqual(set(self.iris[correct_indices].ids),
84+
set(selected.ids))

0 commit comments

Comments
 (0)