@@ -33,7 +33,7 @@ def __init__(
3333 self ,
3434 dataset : Dataset ,
3535 scheme : ValidationScheme = "ho" ,
36- split_train : bool = True ,
36+ separate_nodes : bool = True ,
3737 random_seed : int = 0 ,
3838 n_folds : int = 3 ,
3939 ) -> None :
@@ -42,7 +42,7 @@ def __init__(
4242
4343 :param dataset: Training dataset.
4444 :param random_seed: Seed for random number generation.
45- :param split_train : Perform or not splitting of train (default to split to be used in scoring and
45+ :param separate_nodes : Perform or not splitting of train (default to split to be used in scoring and
4646 threshold search).
4747 """
4848 set_seed (random_seed )
@@ -55,7 +55,7 @@ def __init__(
5555 self .n_folds = n_folds
5656
5757 if scheme == "ho" :
58- self ._split_ho (split_train )
58+ self ._split_ho (separate_nodes )
5959 elif scheme == "cv" :
6060 self ._split_cv ()
6161
@@ -82,6 +82,15 @@ def multilabel(self) -> bool:
8282 """
8383 return self .dataset .multilabel
8484
85+ def _choose_split (self , split_name : str , idx : int | None = None ) -> str :
86+ if idx is not None :
87+ split = f"{ split_name } _{ idx } "
88+ if split not in self .dataset :
89+ split = split_name
90+ else :
91+ split = split_name
92+ return split
93+
8594 def train_utterances (self , idx : int | None = None ) -> list [str ]:
8695 """
8796 Retrieve training utterances from the dataset.
@@ -93,7 +102,7 @@ def train_utterances(self, idx: int | None = None) -> list[str]:
93102 :param idx: Optional index for a specific training split.
94103 :return: List of training utterances.
95104 """
96- split = f" { Split .TRAIN } _ { idx } " if idx is not None else Split . TRAIN
105+ split = self . _choose_split ( Split .TRAIN , idx )
97106 return cast (list [str ], self .dataset [split ][self .dataset .utterance_feature ])
98107
99108 def train_labels (self , idx : int | None = None ) -> ListOfGenericLabels :
@@ -107,7 +116,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
107116 :param idx: Optional index for a specific training split.
108117 :return: List of training labels.
109118 """
110- split = f" { Split .TRAIN } _ { idx } " if idx is not None else Split . TRAIN
119+ split = self . _choose_split ( Split .TRAIN , idx )
111120 return cast (ListOfGenericLabels , self .dataset [split ][self .dataset .label_feature ])
112121
113122 def train_labels_folded (self ) -> list [ListOfGenericLabels ]:
@@ -124,7 +133,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
124133 :param idx: Optional index for a specific validation split.
125134 :return: List of validation utterances.
126135 """
127- split = f" { Split .VALIDATION } _ { idx } " if idx is not None else Split . VALIDATION
136+ split = self . _choose_split ( Split .VALIDATION , idx )
128137 return cast (list [str ], self .dataset [split ][self .dataset .utterance_feature ])
129138
130139 def validation_labels (self , idx : int | None = None ) -> ListOfGenericLabels :
@@ -138,10 +147,10 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
138147 :param idx: Optional index for a specific validation split.
139148 :return: List of validation labels.
140149 """
141- split = f" { Split .VALIDATION } _ { idx } " if idx is not None else Split . VALIDATION
150+ split = self . _choose_split ( Split .VALIDATION , idx )
142151 return cast (ListOfGenericLabels , self .dataset [split ][self .dataset .label_feature ])
143152
144- def test_utterances (self , idx : int | None = None ) -> list [str ]:
153+ def test_utterances (self ) -> list [str ]:
145154 """
146155 Retrieve test utterances from the dataset.
147156
@@ -152,10 +161,9 @@ def test_utterances(self, idx: int | None = None) -> list[str]:
152161 :param idx: Optional index for a specific test split.
153162 :return: List of test utterances.
154163 """
155- split = f"{ Split .TEST } _{ idx } " if idx is not None else Split .TEST
156- return cast (list [str ], self .dataset [split ][self .dataset .utterance_feature ])
164+ return cast (list [str ], self .dataset [Split .TEST ][self .dataset .utterance_feature ])
157165
158- def test_labels (self , idx : int | None = None ) -> ListOfGenericLabels :
166+ def test_labels (self ) -> ListOfGenericLabels :
159167 """
160168 Retrieve test labels from the dataset.
161169
@@ -166,8 +174,7 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
166174 :param idx: Optional index for a specific test split.
167175 :return: List of test labels.
168176 """
169- split = f"{ Split .TEST } _{ idx } " if idx is not None else Split .TEST
170- return cast (ListOfGenericLabels , self .dataset [split ][self .dataset .label_feature ])
177+ return cast (ListOfGenericLabels , self .dataset [Split .TEST ][self .dataset .label_feature ])
171178
172179 def validation_iterator (self ) -> Generator [tuple [list [str ], ListOfLabels , list [str ], ListOfLabels ]]:
173180 if self .scheme == "ho" :
@@ -186,27 +193,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
186193 train_labels = [lab for lab in train_labels if lab is not None ]
187194 yield train_utterances , train_labels , val_utterances , val_labels # type: ignore[misc]
188195
189- def _split_ho (self , split_train : bool ) -> None :
196+ def _split_ho (self , separate_nodes : bool ) -> None :
190197 has_validation_split = any (split .startswith (Split .VALIDATION ) for split in self .dataset )
191198
192- if split_train and Split .TRAIN in self .dataset :
199+ if separate_nodes and Split .TRAIN in self .dataset :
193200 self ._split_train ()
194201
195- if Split .TEST not in self .dataset :
196- test_size = 0.1 if has_validation_split else 0.2
197- self ._split_test (test_size )
198-
199202 if not has_validation_split :
200203 self ._split_validation_from_train ()
201- elif Split .VALIDATION in self .dataset :
202- self ._split_validation ()
203204
204205 for split in self .dataset :
205- n_classes_split = self .dataset .get_n_classes (split )
206- if n_classes_split != self .n_classes :
206+ n_classes_in_split = self .dataset .get_n_classes (split )
207+ if n_classes_in_split != self .n_classes :
207208 message = (
208- f"Number of classes in split '{ split } ' doesn't match initial number of classes "
209- f"({ n_classes_split } != { self .n_classes } )"
209+ f"{ n_classes_in_split = } for '{ split = } ' doesn't match initial number of classes ({ self .n_classes } )"
210210 )
211211 raise ValueError (message )
212212
@@ -225,30 +225,6 @@ def _split_train(self) -> None:
225225 )
226226 self .dataset .pop (Split .TRAIN )
227227
228- def _split_validation (self ) -> None :
229- """
230- Split on two sets.
231-
232- One is for scoring node optimizaton, one is for decision node.
233- """
234- self .dataset [f"{ Split .VALIDATION } _0" ], self .dataset [f"{ Split .VALIDATION } _1" ] = split_dataset (
235- self .dataset ,
236- split = Split .VALIDATION ,
237- test_size = 0.5 ,
238- random_seed = self .random_seed ,
239- allow_oos_in_train = False , # only val data for decision node should contain OOS
240- )
241- self .dataset .pop (Split .VALIDATION )
242-
243- def _split_validation_from_test (self ) -> None :
244- self .dataset [Split .TEST ], self .dataset [Split .VALIDATION ] = split_dataset (
245- self .dataset ,
246- split = Split .TEST ,
247- test_size = 0.5 ,
248- random_seed = self .random_seed ,
249- allow_oos_in_train = True , # both test and validation splits can contain OOS
250- )
251-
252228 def _split_cv (self ) -> None :
253229 extra_splits = [split_name for split_name in self .dataset if split_name not in [Split .TRAIN , Split .TEST ]]
254230 if extra_splits :
@@ -290,27 +266,6 @@ def _split_validation_from_train(self) -> None:
290266 allow_oos_in_train = idx == 1 , # for decision node it's ok to have oos in train
291267 )
292268
293- def _split_test (self , test_size : float ) -> None :
294- """Obtain test set from train."""
295- self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TEST } _0" ] = split_dataset (
296- self .dataset ,
297- split = f"{ Split .TRAIN } _0" ,
298- test_size = test_size ,
299- random_seed = self .random_seed ,
300- )
301- self .dataset [f"{ Split .TRAIN } _1" ], self .dataset [f"{ Split .TEST } _1" ] = split_dataset (
302- self .dataset ,
303- split = f"{ Split .TRAIN } _1" ,
304- test_size = test_size ,
305- random_seed = self .random_seed ,
306- allow_oos_in_train = True ,
307- )
308- self .dataset [Split .TEST ] = concatenate_datasets (
309- [self .dataset [f"{ Split .TEST } _0" ], self .dataset [f"{ Split .TEST } _1" ]],
310- )
311- self .dataset .pop (f"{ Split .TEST } _0" )
312- self .dataset .pop (f"{ Split .TEST } _1" )
313-
314269 def prepare_for_refit (self ) -> None :
315270 if self .scheme == "ho" :
316271 return
0 commit comments