Skip to content

Commit 4530fc6

Browse files
committed
FIX holdout with only a single instance for a class
1 parent 91017ed commit 4530fc6

File tree

2 files changed

+93
-29
lines changed

2 files changed

+93
-29
lines changed

autosklearn/evaluation/__init__.py

Lines changed: 42 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -126,36 +126,9 @@ def run(self, config, instance=None,
126126
include=self.include,
127127
exclude=self.exclude,
128128
disable_file_output=self.disable_file_output)
129-
if self.resampling_strategy != 'test':
130-
if D.info['task'] in CLASSIFICATION_TASKS and \
131-
D.info['task'] != MULTILABEL_CLASSIFICATION:
132-
y = D.data['Y_train'].ravel()
133-
if self.resampling_strategy in ['holdout',
134-
'holdout-iterative-fit']:
135-
cv = StratifiedShuffleSplit(y=y, n_iter=1, train_size=0.67,
136-
test_size=0.33, random_state=1)
137-
elif self.resampling_strategy in ['cv', 'partial-cv',
138-
'partial-cv-iterative-fit']:
139-
cv = StratifiedKFold(y=y,
140-
n_folds=self.resampling_strategy_args[
141-
'folds'],
142-
shuffle=True, random_state=1)
143-
else:
144-
raise ValueError(self.resampling_strategy)
145-
else:
146-
n = D.data['Y_train'].shape[0]
147-
if self.resampling_strategy in ['holdout',
148-
'holdout-iterative-fit']:
149-
cv = ShuffleSplit(n=n, n_iter=1, train_size=0.67,
150-
test_size=0.33, random_state=1)
151-
elif self.resampling_strategy in ['cv', 'partial-cv',
152-
'partial-cv-iterative-fit']:
153-
cv = KFold(n=n,
154-
n_folds=self.resampling_strategy_args['folds'],
155-
shuffle=True, random_state=1)
156-
else:
157-
raise ValueError(self.resampling_strategy)
158129

130+
if self.resampling_strategy != 'test':
131+
cv = self.get_splitter(D)
159132
obj_kwargs['cv'] = cv
160133
if instance is not None:
161134
obj_kwargs['instance'] = instance
@@ -208,3 +181,43 @@ def run(self, config, instance=None,
208181
self.num_run += 1
209182
return status, cost, runtime, additional_run_info
210183

184+
def get_splitter(self, D):
    """Return the resampling splitter matching the task and strategy.

    Parameters
    ----------
    D : data manager
        Object exposing ``D.data['Y_train']`` (the training targets) and
        ``D.info['task']`` (the task identifier).

    Returns
    -------
    StratifiedShuffleSplit, ShuffleSplit, StratifiedKFold or KFold
        A stratified splitter is used for classification tasks other than
        multilabel classification; the plain variant otherwise.  For the
        holdout strategies the stratified split falls back to a plain
        ``ShuffleSplit`` when ``StratifiedShuffleSplit`` refuses the
        targets because the least populated class has too few members.

    Raises
    ------
    ValueError
        If ``self.resampling_strategy`` is neither a holdout nor a
        cross-validation strategy, or if ``StratifiedShuffleSplit`` fails
        for a reason other than an under-populated class.
    """
    holdout_strategies = ('holdout', 'holdout-iterative-fit')
    cv_strategies = ('cv', 'partial-cv', 'partial-cv-iterative-fit')

    y = D.data['Y_train'].ravel()
    n = D.data['Y_train'].shape[0]
    task = D.info['task']
    stratify = (task in CLASSIFICATION_TASKS and
                task != MULTILABEL_CLASSIFICATION)

    if self.resampling_strategy in holdout_strategies:
        if stratify:
            try:
                return StratifiedShuffleSplit(y=y, n_iter=1,
                                              train_size=0.67,
                                              test_size=0.33,
                                              random_state=1)
            except ValueError as e:
                # Match on str(e) rather than e.args[0]: an exception
                # raised without arguments (or with a non-string first
                # argument) must be re-raised as-is, not masked by an
                # IndexError/TypeError coming from this check.
                if 'The least populated class in y has only' \
                        not in str(e):
                    raise
            # A class is too small to stratify on; fall through to the
            # unstratified split below.
        return ShuffleSplit(n=n, n_iter=1, train_size=0.67,
                            test_size=0.33, random_state=1)

    if self.resampling_strategy in cv_strategies:
        n_folds = self.resampling_strategy_args['folds']
        if stratify:
            return StratifiedKFold(y=y, n_folds=n_folds,
                                   shuffle=True, random_state=1)
        return KFold(n=n, n_folds=n_folds, shuffle=True, random_state=1)

    raise ValueError(self.resampling_strategy)
223+

test/test_evaluation/test_evaluation.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,12 @@
1414
import pynisher
1515
from smac.tae.execute_ta_run import StatusType
1616
from smac.stats.stats import Stats
17+
import sklearn.cross_validation
1718

1819
from evaluation_util import get_multiclass_classification_datamanager
20+
from autosklearn.constants import *
1921
from autosklearn.evaluation import ExecuteTaFuncWithQueue
22+
from autosklearn.data.abstract_data_manager import AbstractDataManager
2023

2124

2225
def safe_eval_success_mock(*args, **kwargs):
@@ -189,3 +192,51 @@ def side_effect(*args, **kwargs):
189192
instance_specific='subsample=30')
190193
self.assertEqual(info[0], StatusType.SUCCESS)
191194
self.assertEqual(info[-1], 30)
195+
196+
def test_get_splitter(self):
    """get_splitter must pick the expected splitter class per scenario."""
    ta_args = dict(backend=BackendMock(), autosklearn_seed=1,
                   logger=self.logger, stats=self.stats, memory_limit=3072)
    D = unittest.mock.Mock(spec=AbstractDataManager)
    D.data = dict()
    D.info = dict()

    balanced_labels = [0, 0, 0, 1, 1, 1]
    # Same labels plus one class that has only a single member.
    singleton_labels = [0, 0, 0, 1, 1, 1, 2]
    targets = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

    holdout = dict(resampling_strategy='holdout')
    five_fold = dict(resampling_strategy='cv', folds=5)

    # (task, Y_train, evaluator kwargs, expected splitter class)
    scenarios = [
        # holdout, binary classification
        (BINARY_CLASSIFICATION, balanced_labels, holdout,
         sklearn.cross_validation.StratifiedShuffleSplit),
        # holdout, binary classification, fallback to shuffle split
        (BINARY_CLASSIFICATION, singleton_labels, holdout,
         sklearn.cross_validation.ShuffleSplit),
        # cv, binary classification
        (BINARY_CLASSIFICATION, balanced_labels, five_fold,
         sklearn.cross_validation.StratifiedKFold),
        # cv, binary classification, no fallback anticipated
        (BINARY_CLASSIFICATION, singleton_labels, five_fold,
         sklearn.cross_validation.StratifiedKFold),
        # regression, shuffle split
        (REGRESSION, targets, holdout,
         sklearn.cross_validation.ShuffleSplit),
        # regression cv, KFold
        (REGRESSION, targets, five_fold,
         sklearn.cross_validation.KFold),
    ]

    for task, y_train, strategy_kwargs, expected_class in scenarios:
        D.data['Y_train'] = np.array(y_train)
        D.info['task'] = task

        evaluator_kwargs = dict(ta_args)
        evaluator_kwargs.update(strategy_kwargs)
        ta = ExecuteTaFuncWithQueue(**evaluator_kwargs)

        splitter = ta.get_splitter(D)
        self.assertIsInstance(splitter, expected_class)

0 commit comments

Comments
 (0)