@@ -234,21 +234,24 @@ def run(self, config, instance=None,
234234
235235 def get_splitter (self , D ):
236236 y = D .data ['Y_train' ].ravel ()
237-
237+ train_size = 0.67
238+ if self .resampling_strategy_args :
239+ train_size = self .resampling_strategy_args .get ('train_size' , train_size )
240+ test_size = 1 - train_size
238241 if D .info ['task' ] in CLASSIFICATION_TASKS and \
239242 D .info ['task' ] != MULTILABEL_CLASSIFICATION :
240243
241244 if self .resampling_strategy in ['holdout' ,
242245 'holdout-iterative-fit' ]:
243246 try :
244- cv = StratifiedShuffleSplit (n_splits = 1 , train_size = 0.67 ,
245- test_size = 0.33 , random_state = 1 )
247+ cv = StratifiedShuffleSplit (n_splits = 1 , train_size = train_size ,
248+ test_size = test_size , random_state = 1 )
246249 test_cv = copy .deepcopy (cv )
247250 next (test_cv .split (y , y ))
248251 except ValueError as e :
249252 if 'The least populated class in y has only' in e .args [0 ]:
250- cv = ShuffleSplit (n_splits = 1 , train_size = 0.67 ,
251- test_size = 0.33 , random_state = 1 )
253+ cv = ShuffleSplit (n_splits = 1 , train_size = train_size ,
254+ test_size = test_size , random_state = 1 )
252255 else :
253256 raise
254257
@@ -261,8 +264,8 @@ def get_splitter(self, D):
261264 else :
262265 if self .resampling_strategy in ['holdout' ,
263266 'holdout-iterative-fit' ]:
264- cv = ShuffleSplit (n_splits = 1 , train_size = 0.67 ,
265- test_size = 0.33 , random_state = 1 )
267+ cv = ShuffleSplit (n_splits = 1 , train_size = train_size ,
268+ test_size = test_size , random_state = 1 )
266269 elif self .resampling_strategy in ['cv' , 'partial-cv' ,
267270 'partial-cv-iterative-fit' ]:
268271 cv = KFold (n_splits = self .resampling_strategy_args ['folds' ],
0 commit comments