88from transformers import set_seed
99
1010from autointent import Dataset
11- from autointent .custom_types import ListOfGenericLabels , ListOfLabels , Split , ValidationScheme
11+ from autointent .configs import DataConfig
12+ from autointent .custom_types import FloatFromZeroToOne , ListOfGenericLabels , ListOfLabels , Split
1213
1314from ._stratification import split_dataset
1415
@@ -32,31 +33,27 @@ class DataHandler: # TODO rename to Validator
3233 def __init__ (
3334 self ,
3435 dataset : Dataset ,
35- scheme : ValidationScheme = "ho" ,
36- separate_nodes : bool = True ,
36+ config : DataConfig | None = None ,
3737 random_seed : int = 0 ,
38- n_folds : int = 3 ,
3938 ) -> None :
4039 """
4140 Initialize the data handler.
4241
4342 :param dataset: Training dataset.
4443 :param random_seed: Seed for random number generation.
45- :param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and
46- threshold search).
44+ :param config: config
4745 """
4846 set_seed (random_seed )
4947 self .random_seed = random_seed
5048
5149 self .dataset = dataset
50+ self .config = config if config is not None else DataConfig ()
5251
5352 self .n_classes = self .dataset .n_classes
54- self .scheme = scheme
55- self .n_folds = n_folds
5653
57- if scheme == "ho" :
58- self ._split_ho (separate_nodes )
59- elif scheme == "cv" :
54+ if self . config . scheme == "ho" :
55+ self ._split_ho (self . config . separation_ratio , self . config . validation_size )
56+ elif self . config . scheme == "cv" :
6057 self ._split_cv ()
6158
6259 self .regex_patterns = [
@@ -120,7 +117,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
120117 return cast (ListOfGenericLabels , self .dataset [split ][self .dataset .label_feature ])
121118
122119 def train_labels_folded (self ) -> list [ListOfGenericLabels ]:
123- return [self .train_labels (j ) for j in range (self .n_folds )]
120+ return [self .train_labels (j ) for j in range (self .config . n_folds )]
124121
125122 def validation_utterances (self , idx : int | None = None ) -> list [str ]:
126123 """
@@ -179,14 +176,14 @@ def test_labels(self) -> ListOfGenericLabels:
179176 return cast (ListOfGenericLabels , self .dataset [Split .TEST ][self .dataset .label_feature ])
180177
181178 def validation_iterator (self ) -> Generator [tuple [list [str ], ListOfLabels , list [str ], ListOfLabels ]]:
182- if self .scheme == "ho" :
179+ if self .config . scheme == "ho" :
183180 msg = "Cannot call cross-validation on hold-out DataHandler"
184181 raise RuntimeError (msg )
185182
186- for j in range (self .n_folds ):
183+ for j in range (self .config . n_folds ):
187184 val_utterances = self .train_utterances (j )
188185 val_labels = self .train_labels (j )
189- train_folds = [i for i in range (self .n_folds ) if i != j ]
186+ train_folds = [i for i in range (self .config . n_folds ) if i != j ]
190187 train_utterances = [ut for i_fold in train_folds for ut in self .train_utterances (i_fold )]
191188 train_labels = [lab for i_fold in train_folds for lab in self .train_labels (i_fold )]
192189
@@ -195,14 +192,14 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
195192 train_labels = [lab for lab in train_labels if lab is not None ]
196193 yield train_utterances , train_labels , val_utterances , val_labels # type: ignore[misc]
197194
198- def _split_ho (self , separate_nodes : bool ) -> None :
195+ def _split_ho (self , separation_ratio : FloatFromZeroToOne | None , validation_size : FloatFromZeroToOne ) -> None :
199196 has_validation_split = any (split .startswith (Split .VALIDATION ) for split in self .dataset )
200197
201- if separate_nodes and Split .TRAIN in self .dataset :
202- self ._split_train ()
198+ if separation_ratio is not None and Split .TRAIN in self .dataset :
199+ self ._split_train (separation_ratio )
203200
204201 if not has_validation_split :
205- self ._split_validation_from_train ()
202+ self ._split_validation_from_train (validation_size )
206203
207204 for split in self .dataset :
208205 n_classes_in_split = self .dataset .get_n_classes (split )
@@ -212,7 +209,7 @@ def _split_ho(self, separate_nodes: bool) -> None:
212209 )
213210 raise ValueError (message )
214211
215- def _split_train (self ) -> None :
212+ def _split_train (self , ratio : FloatFromZeroToOne ) -> None :
216213 """
217214 Split on two sets.
218215
@@ -221,40 +218,32 @@ def _split_train(self) -> None:
221218 self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TRAIN } _1" ] = split_dataset (
222219 self .dataset ,
223220 split = Split .TRAIN ,
224- test_size = 0.5 ,
221+ test_size = ratio ,
225222 random_seed = self .random_seed ,
226223 allow_oos_in_train = False , # only train data for decision node should contain OOS
227224 )
228225 self .dataset .pop (Split .TRAIN )
229226
230227 def _split_cv (self ) -> None :
231- extra_splits = [split_name for split_name in self .dataset if split_name not in [Split .TRAIN , Split .TEST ]]
232- if extra_splits :
233- self .dataset [Split .TRAIN ] = concatenate_datasets (
234- [self .dataset .pop (split_name ) for split_name in extra_splits ]
235- )
236-
237- if Split .TEST not in self .dataset :
238- self .dataset [Split .TRAIN ], self .dataset [Split .TEST ] = split_dataset (
239- self .dataset , split = Split .TRAIN , test_size = 0.2 , random_seed = self .random_seed , allow_oos_in_train = True
240- )
228+ extra_splits = [split_name for split_name in self .dataset if split_name != Split .TEST ]
229+ self .dataset [Split .TRAIN ] = concatenate_datasets ([self .dataset .pop (split_name ) for split_name in extra_splits ])
241230
242- for j in range (self .n_folds - 1 ):
231+ for j in range (self .config . n_folds - 1 ):
243232 self .dataset [Split .TRAIN ], self .dataset [f"{ Split .TRAIN } _{ j } " ] = split_dataset (
244233 self .dataset ,
245234 split = Split .TRAIN ,
246- test_size = 1 / (self .n_folds - j ),
235+ test_size = 1 / (self .config . n_folds - j ),
247236 random_seed = self .random_seed ,
248237 allow_oos_in_train = True ,
249238 )
250- self .dataset [f"{ Split .TRAIN } _{ self .n_folds - 1 } " ] = self .dataset .pop (Split .TRAIN )
239+ self .dataset [f"{ Split .TRAIN } _{ self .config . n_folds - 1 } " ] = self .dataset .pop (Split .TRAIN )
251240
252- def _split_validation_from_train (self ) -> None :
241+ def _split_validation_from_train (self , size : float ) -> None :
253242 if Split .TRAIN in self .dataset :
254243 self .dataset [Split .TRAIN ], self .dataset [Split .VALIDATION ] = split_dataset (
255244 self .dataset ,
256245 split = Split .TRAIN ,
257- test_size = 0.2 ,
246+ test_size = size ,
258247 random_seed = self .random_seed ,
259248 allow_oos_in_train = True ,
260249 )
@@ -263,13 +252,13 @@ def _split_validation_from_train(self) -> None:
263252 self .dataset [f"{ Split .TRAIN } _{ idx } " ], self .dataset [f"{ Split .VALIDATION } _{ idx } " ] = split_dataset (
264253 self .dataset ,
265254 split = f"{ Split .TRAIN } _{ idx } " ,
266- test_size = 0.2 ,
255+ test_size = size ,
267256 random_seed = self .random_seed ,
268257 allow_oos_in_train = idx == 1 , # for decision node it's ok to have oos in train
269258 )
270259
271260 def prepare_for_refit (self ) -> None :
272- if self .scheme == "ho" :
261+ if self .config . scheme == "ho" :
273262 return
274263
275264 train_folds = [split_name for split_name in self .dataset if split_name .startswith (Split .TRAIN )]
@@ -278,7 +267,7 @@ def prepare_for_refit(self) -> None:
278267 self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TRAIN } _1" ] = split_dataset (
279268 self .dataset ,
280269 split = Split .TRAIN ,
281- test_size = 0.5 ,
270+ test_size = self . config . separation_ratio or 0.5 ,
282271 random_seed = self .random_seed ,
283272 allow_oos_in_train = False ,
284273 )
0 commit comments