1111from AnyQt .QtGui import QStandardItemModel , QStandardItem
1212from AnyQt .QtCore import Qt , QSize
1313
14- from Orange .data import Table
14+ from Orange .data import Table , DiscreteVariable
1515from Orange .data .sql .table import SqlTable , AUTO_DL_LIMIT
1616import Orange .evaluation
1717import Orange .classification
2222from Orange .preprocess .preprocess import Preprocess
2323from Orange .preprocess import RemoveNaNClasses
2424from Orange .widgets import widget , gui , settings
25+ from Orange .widgets .utils .itemmodels import DomainModel
2526from Orange .widgets .widget import OWWidget , Msg
2627
2728Input = namedtuple (
@@ -137,12 +138,12 @@ class OWTestLearners(OWWidget):
137138 outputs = [("Predictions" , Table ),
138139 ("Evaluation Results" , Results )]
139140
140- settingsHandler = settings .ClassValuesContextHandler ( )
141+ settingsHandler = settings .PerfectDomainContextHandler ( metas_in_res = True )
141142
142143 #: Resampling/testing types
143144 KFold , ShuffleSplit , LeaveOneOut , TestOnTrain , TestOnTest = 0 , 1 , 2 , 3 , 4
144145 #: Numbers of folds
145- NFolds = [2 , 3 , 5 , 10 , 20 ]
146+ NFolds = [2 , 3 , 5 , 10 , 20 , "From feature" ]
146147 #: Number of repetitions
147148 NRepeats = [2 , 3 , 5 , 10 , 20 , 50 , 100 ]
148149 #: Sample sizes
@@ -160,6 +161,8 @@ class OWTestLearners(OWWidget):
160161 sample_size = settings .Setting (9 )
161162 #: Stratified sampling for Random Sampling
162163 shuffle_stratified = settings .Setting (True )
164+ # CV where nr. of feature values determines nr. of folds
165+ fold_feature = settings .ContextSetting (None )
163166
164167 TARGET_AVERAGE = "(Average over classes)"
165168 class_selection = settings .ContextSetting (TARGET_AVERAGE )
@@ -204,13 +207,18 @@ def __init__(self):
204207
205208 gui .appendRadioButton (rbox , "Cross validation" )
206209 ibox = gui .indentedBox (rbox )
207- gui .comboBox (
210+ self . n_folds_combo = gui .comboBox (
208211 ibox , self , "n_folds" , label = "Number of folds: " ,
209212 items = [str (x ) for x in self .NFolds ], maximumContentsLength = 3 ,
210213 orientation = Qt .Horizontal , callback = self .kfold_changed )
211- gui .checkBox (
214+ self . stratified_check = gui .checkBox (
212215 ibox , self , "cv_stratified" , "Stratified" ,
213216 callback = self .kfold_changed )
217+ self .feature_model = DomainModel (
218+ order = DomainModel .METAS , valid_types = DiscreteVariable )
219+ self .features_combo = gui .comboBox (
220+ ibox , self , "fold_feature" , model = self .feature_model ,
221+ orientation = Qt .Horizontal , callback = self .fold_feature_changed )
214222
215223 gui .appendRadioButton (rbox , "Random sampling" )
216224 ibox = gui .indentedBox (rbox )
@@ -257,9 +265,32 @@ def __init__(self):
257265 box = gui .vBox (self .mainArea , "Evaluation Results" )
258266 box .layout ().addWidget (self .view )
259267
268+ @property
269+ def kfold_feature_index (self ):
270+ return len (self .NFolds ) - 1
271+
260272 def sizeHint (self ):
261273 return QSize (780 , 1 )
262274
275+ def __hide_show_feature_combo (self ):
276+ cv_feature = self .n_folds == self .kfold_feature_index
277+ self .stratified_check .setVisible (not cv_feature )
278+ self .features_combo .setVisible (cv_feature )
279+ if self .fold_feature is None and cv_feature and self .feature_model :
280+ self .fold_feature = self .feature_model [0 ]
281+
282+ def _update_controls (self ):
283+ self .fold_feature = None
284+ self .feature_model .set_domain (None )
285+ if self .data :
286+ self .feature_model .set_domain (self .data .domain )
287+ enable = bool (self .feature_model )
288+ item = self .n_folds_combo .model ().item (self .kfold_feature_index )
289+ item .setEnabled (enable )
290+ if self .n_folds == self .kfold_feature_index and not enable :
291+ self .n_folds = 3
292+ self .__hide_show_feature_combo ()
293+
263294 def set_learner (self , learner , key ):
264295 """
265296 Set the input `learner` for `key`.
@@ -310,9 +341,10 @@ def set_train_data(self, data):
310341
311342 self .data = data
312343 self .closeContext ()
344+ self ._update_controls ()
313345 if data is not None :
314346 self ._update_class_selection ()
315- self .openContext (data .domain . class_var )
347+ self .openContext (data .domain )
316348 self ._invalidate ()
317349
318350 def set_test_data (self , data ):
@@ -372,6 +404,10 @@ def handleNewSignals(self):
372404
373405 def kfold_changed (self ):
374406 self .resampling = OWTestLearners .KFold
407+ self .__hide_show_feature_combo ()
408+ self ._param_changed ()
409+
410+ def fold_feature_changed (self ):
375411 self ._param_changed ()
376412
377413 def shuffle_split_changed (self ):
@@ -429,17 +465,22 @@ def update_progress(finished):
429465
430466 with self .progressBar ():
431467 try :
432- folds = self .NFolds [self .n_folds ]
433468 if self .resampling == OWTestLearners .KFold :
434- if len (self .data ) < folds :
435- self .Error .too_many_folds ()
436- return
437- warnings = []
438- results = Orange .evaluation .CrossValidation (
439- self .data , learners , k = folds ,
440- random_state = rstate , warnings = warnings , ** common_args )
441- if warnings :
442- self .warning (warnings [0 ])
469+ if self .n_folds == self .kfold_feature_index :
470+ results = Orange .evaluation .CrossValidationFeature (
471+ self .data , learners , self .fold_feature ,
472+ ** common_args )
473+ else :
474+ folds = self .NFolds [self .n_folds ]
475+ if len (self .data ) < folds :
476+ self .Error .too_many_folds ()
477+ return
478+ warnings = []
479+ results = Orange .evaluation .CrossValidation (
480+ self .data , learners , k = folds ,
481+ random_state = rstate , warnings = warnings , ** common_args )
482+ if warnings :
483+ self .warning (warnings [0 ])
443484 elif self .resampling == OWTestLearners .LeaveOneOut :
444485 results = Orange .evaluation .LeaveOneOut (
445486 self .data , learners , ** common_args )
0 commit comments