Skip to content

Commit bb4c530

Browse files
committed
OWTestLearners: Cross validation by feature
1 parent 769024c commit bb4c530

File tree

2 files changed

+85
-18
lines changed

2 files changed

+85
-18
lines changed

Orange/widgets/evaluate/owtestlearners.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from AnyQt.QtGui import QStandardItemModel, QStandardItem
1212
from AnyQt.QtCore import Qt, QSize
1313

14-
from Orange.data import Table
14+
from Orange.data import Table, DiscreteVariable
1515
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
1616
import Orange.evaluation
1717
import Orange.classification
@@ -22,6 +22,7 @@
2222
from Orange.preprocess.preprocess import Preprocess
2323
from Orange.preprocess import RemoveNaNClasses
2424
from Orange.widgets import widget, gui, settings
25+
from Orange.widgets.utils.itemmodels import DomainModel
2526
from Orange.widgets.widget import OWWidget, Msg
2627

2728
Input = namedtuple(
@@ -137,12 +138,12 @@ class OWTestLearners(OWWidget):
137138
outputs = [("Predictions", Table),
138139
("Evaluation Results", Results)]
139140

140-
settingsHandler = settings.ClassValuesContextHandler()
141+
settingsHandler = settings.PerfectDomainContextHandler(metas_in_res=True)
141142

142143
#: Resampling/testing types
143144
KFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest = 0, 1, 2, 3, 4
144145
#: Numbers of folds
145-
NFolds = [2, 3, 5, 10, 20]
146+
NFolds = [2, 3, 5, 10, 20, "From feature"]
146147
#: Number of repetitions
147148
NRepeats = [2, 3, 5, 10, 20, 50, 100]
148149
#: Sample sizes
@@ -160,6 +161,8 @@ class OWTestLearners(OWWidget):
160161
sample_size = settings.Setting(9)
161162
#: Stratified sampling for Random Sampling
162163
shuffle_stratified = settings.Setting(True)
164+
# Cross validation where the number of feature values determines the number of folds
165+
fold_feature = settings.ContextSetting(None)
163166

164167
TARGET_AVERAGE = "(Average over classes)"
165168
class_selection = settings.ContextSetting(TARGET_AVERAGE)
@@ -204,13 +207,18 @@ def __init__(self):
204207

205208
gui.appendRadioButton(rbox, "Cross validation")
206209
ibox = gui.indentedBox(rbox)
207-
gui.comboBox(
210+
self.n_folds_combo = gui.comboBox(
208211
ibox, self, "n_folds", label="Number of folds: ",
209212
items=[str(x) for x in self.NFolds], maximumContentsLength=3,
210213
orientation=Qt.Horizontal, callback=self.kfold_changed)
211-
gui.checkBox(
214+
self.stratified_check = gui.checkBox(
212215
ibox, self, "cv_stratified", "Stratified",
213216
callback=self.kfold_changed)
217+
self.feature_model = DomainModel(
218+
order=DomainModel.METAS, valid_types=DiscreteVariable)
219+
self.features_combo = gui.comboBox(
220+
ibox, self, "fold_feature", model=self.feature_model,
221+
orientation=Qt.Horizontal, callback=self.fold_feature_changed)
214222

215223
gui.appendRadioButton(rbox, "Random sampling")
216224
ibox = gui.indentedBox(rbox)
@@ -257,9 +265,32 @@ def __init__(self):
257265
box = gui.vBox(self.mainArea, "Evaluation Results")
258266
box.layout().addWidget(self.view)
259267

268+
@property
269+
def kfold_feature_index(self):
# Index of the last entry of NFolds; that entry is the "From feature" option.
270+
return len(self.NFolds) - 1
271+
260272
def sizeHint(self):
261273
return QSize(780, 1)
262274

275+
def __hide_show_feature_combo(self):
# Show either the "Stratified" check box or the fold-feature combo, depending
# on whether the selected fold option is the feature-based one.
276+
# True when the n_folds setting points at the feature-based entry.
cv_feature = self.n_folds == self.kfold_feature_index
277+
self.stratified_check.setVisible(not cv_feature)
278+
self.features_combo.setVisible(cv_feature)
279+
# No feature chosen yet while feature-based folds are selected: fall back to
# the first item of the feature model, provided the model is truthy.
if self.fold_feature is None and cv_feature and self.feature_model:
280+
self.fold_feature = self.feature_model[0]
281+
282+
def _update_controls(self):
# Reset the fold-feature selection, refill the feature model from self.data
# and bring the fold controls in line with the result.
# NOTE(review): indentation is lost in this view, so the exact extent of the
# `if` bodies below cannot be confirmed from here.
283+
self.fold_feature = None
284+
# Clear the model first; it is refilled below only when self.data is truthy.
self.feature_model.set_domain(None)
285+
if self.data:
286+
self.feature_model.set_domain(self.data.domain)
287+
# The feature-based fold entry is enabled only when the model is truthy.
enable = bool(self.feature_model)
288+
item = self.n_folds_combo.model().item(self.kfold_feature_index)
289+
item.setEnabled(enable)
290+
# Feature-based folds are selected but unavailable: fall back to index 3,
# i.e. NFolds[3] == 10 folds.
if self.n_folds == self.kfold_feature_index and not enable:
291+
self.n_folds = 3
292+
self.__hide_show_feature_combo()
293+
263294
def set_learner(self, learner, key):
264295
"""
265296
Set the input `learner` for `key`.
@@ -310,9 +341,10 @@ def set_train_data(self, data):
310341

311342
self.data = data
312343
self.closeContext()
344+
self._update_controls()
313345
if data is not None:
314346
self._update_class_selection()
315-
self.openContext(data.domain.class_var)
347+
self.openContext(data.domain)
316348
self._invalidate()
317349

318350
def set_test_data(self, data):
@@ -372,6 +404,10 @@ def handleNewSignals(self):
372404

373405
def kfold_changed(self):
# Callback of the "Number of folds" combo and the "Stratified" check box:
# select k-fold resampling, refresh which fold controls are visible, then
# call _param_changed().
374406
self.resampling = OWTestLearners.KFold
407+
self.__hide_show_feature_combo()
408+
self._param_changed()
409+
410+
def fold_feature_changed(self):
# Callback of the fold-feature combo box; simply delegates to _param_changed().
375411
self._param_changed()
376412

377413
def shuffle_split_changed(self):
@@ -429,17 +465,22 @@ def update_progress(finished):
429465

430466
with self.progressBar():
431467
try:
432-
folds = self.NFolds[self.n_folds]
433468
if self.resampling == OWTestLearners.KFold:
434-
if len(self.data) < folds:
435-
self.Error.too_many_folds()
436-
return
437-
warnings = []
438-
results = Orange.evaluation.CrossValidation(
439-
self.data, learners, k=folds,
440-
random_state=rstate, warnings=warnings, **common_args)
441-
if warnings:
442-
self.warning(warnings[0])
469+
if self.n_folds == self.kfold_feature_index:
470+
results = Orange.evaluation.CrossValidationFeature(
471+
self.data, learners, self.fold_feature,
472+
**common_args)
473+
else:
474+
folds = self.NFolds[self.n_folds]
475+
if len(self.data) < folds:
476+
self.Error.too_many_folds()
477+
return
478+
warnings = []
479+
results = Orange.evaluation.CrossValidation(
480+
self.data, learners, k=folds,
481+
random_state=rstate, warnings=warnings, **common_args)
482+
if warnings:
483+
self.warning(warnings[0])
443484
elif self.resampling == OWTestLearners.LeaveOneOut:
444485
results = Orange.evaluation.LeaveOneOut(
445486
self.data, learners, **common_args)

Orange/widgets/evaluate/tests/test_owtestlearners.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
import unittest
55

6-
from Orange.data import Table
6+
from Orange.data import Table, Domain
77
from Orange.classification import MajorityLearner
88
from Orange.regression import MeanLearner
9+
from Orange.modelling import ConstantLearner
910

1011
from Orange.evaluation import Results, TestOnTestData
1112
from Orange.widgets.tests.base import WidgetTest
@@ -38,6 +39,31 @@ def test_basic(self):
3839
self.assertIsNotNone(res.domain)
3940
self.assertIsNotNone(res.data)
4041

42+
def test_feature_cv_combo(self):
# Checks the enabled state of fold-combo item `i` and the visibility of the
# feature combo / stratified check box for three inputs: the plain data,
# the re-domained data, and no data.
43+
# `i` is the fold-combo index under test; presumably the "From feature"
# entry (index 5 of the widget's NFolds) -- verify against the widget.
data, i = Table("iris"), 5
44+
attrs = data.domain.attributes
45+
# NOTE(review): if Domain's positional order is (attributes, class_vars,
# metas), this makes the last attribute the class variable and the original
# class variable(s) the metas -- confirm this is the intended construction.
domain = Domain(attrs[:-1], attrs[-1], data.domain.class_vars)
46+
data_with_disc_metas = Table.from_table(domain, data)
47+
48+
self.send_signal("Learner", ConstantLearner(), 0)
49+
# Stage 1: plain data -- item `i` is disabled, feature combo hidden,
# stratified check box not hidden.
self.send_signal("Data", data)
50+
self.assertFalse(self.widget.n_folds_combo.model().item(i).isEnabled())
51+
self.assertTrue(self.widget.features_combo.isHidden())
52+
self.assertFalse(self.widget.stratified_check.isHidden())
53+
54+
# Stage 2: re-domained data -- item `i` becomes enabled; after emitting
# `activated` and setting the current index to `i`, the visibility of the
# two controls is swapped and the feature combo model holds one entry.
self.send_signal("Data", data_with_disc_metas)
55+
self.assertTrue(self.widget.n_folds_combo.model().item(i).isEnabled())
56+
self.widget.n_folds_combo.activated.emit(i)
57+
self.widget.n_folds_combo.setCurrentIndex(i)
58+
self.assertTrue(self.widget.stratified_check.isHidden())
59+
self.assertFalse(self.widget.features_combo.isHidden())
60+
self.assertEqual(len(self.widget.features_combo.model()), 1)
61+
62+
# Stage 3: data removed -- same expectations as in stage 1.
self.send_signal("Data", None)
63+
self.assertFalse(self.widget.n_folds_combo.model().item(i).isEnabled())
64+
self.assertTrue(self.widget.features_combo.isHidden())
65+
self.assertFalse(self.widget.stratified_check.isHidden())
66+
4167

4268
class TestHelpers(unittest.TestCase):
4369
def test_results_one_vs_rest(self):
@@ -64,4 +90,4 @@ def test_results_one_vs_rest(self):
6490

6591
np.testing.assert_equal(r1.row_indices, res.row_indices)
6692
np.testing.assert_equal(r2.row_indices, res.row_indices)
67-
np.testing.assert_equal(r3.row_indices, res.row_indices)
93+
np.testing.assert_equal(r3.row_indices, res.row_indices)

0 commit comments

Comments
 (0)