@@ -46,6 +46,7 @@ def __init__(
4646 threshold search).
4747 """
4848 set_seed (random_seed )
49+ self .random_seed = random_seed
4950
5051 self .dataset = dataset
5152
@@ -54,9 +55,9 @@ def __init__(
5455 self .n_folds = n_folds
5556
5657 if scheme == "ho" :
57- self ._split_ho (random_seed , split_train )
58+ self ._split_ho (split_train )
5859 elif scheme == "cv" :
59- self ._split_cv (random_seed )
60+ self ._split_cv ()
6061
6162 self .regexp_patterns = [
6263 RegexPatterns (
@@ -185,20 +186,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
185186 train_labels = [lab for lab in train_labels if lab is not None ]
186187 yield train_utterances , train_labels , val_utterances , val_labels # type: ignore[misc]
187188
188- def _split_ho (self , random_seed : int , split_train : bool ) -> None :
189+ def _split_ho (self , split_train : bool ) -> None :
189190 has_validation_split = any (split .startswith (Split .VALIDATION ) for split in self .dataset )
190191
191192 if split_train and Split .TRAIN in self .dataset :
192- self ._split_train (random_seed )
193+ self ._split_train ()
193194
194195 if Split .TEST not in self .dataset :
195196 test_size = 0.1 if has_validation_split else 0.2
196- self ._split_test (test_size , random_seed )
197+ self ._split_test (test_size )
197198
198199 if not has_validation_split :
199- self ._split_validation_from_train (random_seed )
200+ self ._split_validation_from_train ()
200201 elif Split .VALIDATION in self .dataset :
201- self ._split_validation (random_seed )
202+ self ._split_validation ()
202203
203204 for split in self .dataset :
204205 n_classes_split = self .dataset .get_n_classes (split )
@@ -209,7 +210,7 @@ def _split_ho(self, random_seed: int, split_train: bool) -> None:
209210 )
210211 raise ValueError (message )
211212
212- def _split_train (self , random_seed : int ) -> None :
213+ def _split_train (self ) -> None :
213214 """
214215 Split on two sets.
215216
@@ -219,12 +220,12 @@ def _split_train(self, random_seed: int) -> None:
219220 self .dataset ,
220221 split = Split .TRAIN ,
221222 test_size = 0.5 ,
222- random_seed = random_seed ,
223+ random_seed = self . random_seed ,
223224 allow_oos_in_train = False , # only train data for decision node should contain OOS
224225 )
225226 self .dataset .pop (Split .TRAIN )
226227
227- def _split_validation (self , random_seed : int ) -> None :
228+ def _split_validation (self ) -> None :
228229 """
229230 Split on two sets.
230231
@@ -234,21 +235,21 @@ def _split_validation(self, random_seed: int) -> None:
234235 self .dataset ,
235236 split = Split .VALIDATION ,
236237 test_size = 0.5 ,
237- random_seed = random_seed ,
238+ random_seed = self . random_seed ,
238239 allow_oos_in_train = False , # only val data for decision node should contain OOS
239240 )
240241 self .dataset .pop (Split .VALIDATION )
241242
242- def _split_validation_from_test (self , random_seed : int ) -> None :
243+ def _split_validation_from_test (self ) -> None :
243244 self .dataset [Split .TEST ], self .dataset [Split .VALIDATION ] = split_dataset (
244245 self .dataset ,
245246 split = Split .TEST ,
246247 test_size = 0.5 ,
247- random_seed = random_seed ,
248+ random_seed = self . random_seed ,
248249 allow_oos_in_train = True , # both test and validation splits can contain OOS
249250 )
250251
251- def _split_cv (self , random_seed : int ) -> None :
252+ def _split_cv (self ) -> None :
252253 extra_splits = [split_name for split_name in self .dataset if split_name not in [Split .TRAIN , Split .TEST ]]
253254 if extra_splits :
254255 self .dataset [Split .TRAIN ] = concatenate_datasets (
@@ -257,26 +258,26 @@ def _split_cv(self, random_seed: int) -> None:
257258
258259 if Split .TEST not in self .dataset :
259260 self .dataset [Split .TRAIN ], self .dataset [Split .TEST ] = split_dataset (
260- self .dataset , split = Split .TRAIN , test_size = 0.2 , random_seed = random_seed , allow_oos_in_train = True
261+ self .dataset , split = Split .TRAIN , test_size = 0.2 , random_seed = self . random_seed , allow_oos_in_train = True
261262 )
262263
263264 for j in range (self .n_folds - 1 ):
264265 self .dataset [Split .TRAIN ], self .dataset [f"{ Split .TRAIN } _{ j } " ] = split_dataset (
265266 self .dataset ,
266267 split = Split .TRAIN ,
267268 test_size = 1 / (self .n_folds - j ),
268- random_seed = random_seed ,
269+ random_seed = self . random_seed ,
269270 allow_oos_in_train = True ,
270271 )
271272 self .dataset [f"{ Split .TRAIN } _{ self .n_folds - 1 } " ] = self .dataset .pop (Split .TRAIN )
272273
273- def _split_validation_from_train (self , random_seed : int ) -> None :
274+ def _split_validation_from_train (self ) -> None :
274275 if Split .TRAIN in self .dataset :
275276 self .dataset [Split .TRAIN ], self .dataset [Split .VALIDATION ] = split_dataset (
276277 self .dataset ,
277278 split = Split .TRAIN ,
278279 test_size = 0.2 ,
279- random_seed = random_seed ,
280+ random_seed = self . random_seed ,
280281 allow_oos_in_train = True ,
281282 )
282283 else :
@@ -285,27 +286,46 @@ def _split_validation_from_train(self, random_seed: int) -> None:
285286 self .dataset ,
286287 split = f"{ Split .TRAIN } _{ idx } " ,
287288 test_size = 0.2 ,
288- random_seed = random_seed ,
289+ random_seed = self . random_seed ,
289290 allow_oos_in_train = idx == 1 , # for decision node it's ok to have oos in train
290291 )
291292
292- def _split_test (self , test_size : float , random_seed : int ) -> None :
293+ def _split_test (self , test_size : float ) -> None :
293294 """Obtain test set from train."""
294295 self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TEST } _0" ] = split_dataset (
295296 self .dataset ,
296297 split = f"{ Split .TRAIN } _0" ,
297298 test_size = test_size ,
298- random_seed = random_seed ,
299+ random_seed = self . random_seed ,
299300 )
300301 self .dataset [f"{ Split .TRAIN } _1" ], self .dataset [f"{ Split .TEST } _1" ] = split_dataset (
301302 self .dataset ,
302303 split = f"{ Split .TRAIN } _1" ,
303304 test_size = test_size ,
304- random_seed = random_seed ,
305+ random_seed = self . random_seed ,
305306 allow_oos_in_train = True ,
306307 )
307308 self .dataset [Split .TEST ] = concatenate_datasets (
308309 [self .dataset [f"{ Split .TEST } _0" ], self .dataset [f"{ Split .TEST } _1" ]],
309310 )
310311 self .dataset .pop (f"{ Split .TEST } _0" )
311312 self .dataset .pop (f"{ Split .TEST } _1" )
313+
314+ def prepare_for_refit (self ) -> None :
315+ if self .scheme == "ho" :
316+ return
317+
318+ train_folds = [split_name for split_name in self .dataset if split_name .startswith ("train" )]
319+ self .dataset [Split .TRAIN ] = concatenate_datasets ([self .dataset [name ] for name in train_folds ])
320+ for name in train_folds :
321+ self .dataset .pop (name )
322+
323+ self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TRAIN } _1" ] = split_dataset (
324+ self .dataset ,
325+ split = Split .TRAIN ,
326+ test_size = 0.5 ,
327+ random_seed = self .random_seed ,
328+ allow_oos_in_train = False ,
329+ )
330+
331+ self .dataset .pop (Split .TRAIN )
0 commit comments