Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Orange/widgets/evaluate/owconfusionmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def set_results(self, results):

data = None
if results is not None and results.data is not None:
data = results.data
data = results.data[results.row_indices]

if data is not None and not data.domain.has_discrete_class:
self.Error.no_regression()
Expand Down Expand Up @@ -374,9 +374,8 @@ def commit(self):
data.name = learner_name

if selected:
row_indices = self.results.row_indices[selected]
annotated_data = create_annotated_table(data, row_indices)
data = data[row_indices]
annotated_data = create_annotated_table(data, selected)
data = data[selected]
else:
annotated_data = create_annotated_table(data, [])
data = None
Expand Down
25 changes: 19 additions & 6 deletions Orange/widgets/evaluate/tests/test_owconfusionmatrix.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# pylint: disable=missing-docstring
import numpy as np

from Orange.data import Table
from Orange.classification import NaiveBayesLearner, TreeLearner
from Orange.regression import MeanLearner
from Orange.evaluation.testing import CrossValidation, TestOnTrainingData
from Orange.evaluation.testing import CrossValidation, TestOnTrainingData, \
ShuffleSplit
from Orange.widgets.evaluate.owconfusionmatrix import OWConfusionMatrix
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin

Expand All @@ -16,11 +18,11 @@ def setUpClass(cls):

bayes = NaiveBayesLearner()
tree = TreeLearner()
iris = cls.data
cls.iris = cls.data
titanic = Table("titanic")
common = dict(k=3, store_data=True)
cls.results_1_iris = CrossValidation(iris, [bayes], **common)
cls.results_2_iris = CrossValidation(iris, [bayes, tree], **common)
cls.results_1_iris = CrossValidation(cls.iris, [bayes], **common)
cls.results_2_iris = CrossValidation(cls.iris, [bayes, tree], **common)
cls.results_2_titanic = CrossValidation(titanic, [bayes, tree],
**common)

Expand Down Expand Up @@ -59,8 +61,7 @@ def _select_data(self):
def test_show_error_on_regression(self):
"""On regression data, the widget must show error"""
housing = Table("housing")
results = TestOnTrainingData(housing, [MeanLearner()])
results.data = housing
results = TestOnTrainingData(housing, [MeanLearner()], store_data=True)
self.send_signal("Evaluation Results", results)
self.assertTrue(self.widget.Error.no_regression.is_shown())
self.send_signal("Evaluation Results", None)
Expand All @@ -69,3 +70,15 @@ def test_show_error_on_regression(self):
self.assertTrue(self.widget.Error.no_regression.is_shown())
self.send_signal("Evaluation Results", self.results_1_iris)
self.assertFalse(self.widget.Error.no_regression.is_shown())

def test_row_indices(self):
"""Map data instances when using random shuffling"""
results = ShuffleSplit(self.iris, [NaiveBayesLearner()],
store_data=True)
self.send_signal("Evaluation Results", results)
self.widget.select_correct()
selected = self.get_output("Selected Data")
correct = np.equal(results.actual, results.predicted)[0]
correct_indices = results.row_indices[correct]
self.assertSetEqual(set(self.iris[correct_indices].ids),
set(selected.ids))