@@ -50,40 +50,7 @@ def __init__(
5050
5151 self .n_classes = self .dataset .n_classes
5252
53- if Split .TEST not in self .dataset :
54- self .dataset [Split .TRAIN ], self .dataset [Split .TEST ] = split_dataset (
55- self .dataset ,
56- split = Split .TRAIN ,
57- test_size = 0.2 ,
58- random_seed = random_seed ,
59- )
60-
61- self .dataset [f"{ Split .TRAIN } _0" ], self .dataset [f"{ Split .TRAIN } _1" ] = split_dataset (
62- self .dataset ,
63- split = Split .TRAIN ,
64- test_size = 0.5 ,
65- random_seed = random_seed ,
66- )
67- self .dataset .pop (Split .TRAIN )
68-
69- for idx in range (2 ):
70- self .dataset [f"{ Split .TRAIN } _{ idx } " ], self .dataset [f"{ Split .VALIDATION } _{ idx } " ] = split_dataset (
71- self .dataset ,
72- split = f"{ Split .TRAIN } _{ idx } " ,
73- test_size = 0.2 ,
74- random_seed = random_seed ,
75- )
76-
77- for split in self .dataset :
78- if split == Split .OOS :
79- continue
80- n_classes_split = self .dataset .get_n_classes (split )
81- if n_classes_split != self .n_classes :
82- message = (
83- f"Number of classes in split '{ split } ' doesn't match initial number of classes "
84- f"({ n_classes_split } != { self .n_classes } )"
85- )
86- raise ValueError (message )
53+ self ._split (random_seed )
8754
8855 self .regexp_patterns = [
8956 RegexPatterns (
@@ -162,14 +129,15 @@ def test_labels(self, idx: int | None = None) -> list[LabelType]:
162129 split = f"{ Split .TEST } _{ idx } " if idx is not None else Split .TEST
163130 return cast (list [LabelType ], self .dataset [split ][self .dataset .label_feature ])
164131
def oos_utterances(self, idx: int | None = None) -> list[str]:
    """
    Get the out-of-scope utterances.

    :param idx: Optional index of an out-of-scope sub-split. When given, the split named
        ``f"{Split.OOS}_{idx}"`` is read; otherwise the unsuffixed ``Split.OOS`` split is read.
    :return: List of out-of-scope utterances if the requested split is available,
        otherwise an empty list.
    """
    split = f"{Split.OOS}_{idx}" if idx is not None else Split.OOS
    # Test for the exact split instead of calling has_oos_samples(): that helper is true for
    # any split whose name starts with Split.OOS, so it can hold while this split is absent,
    # which would turn the lookup below into a KeyError.
    if split not in self.dataset:
        return []
    return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
174142
175143 def has_oos_samples (self ) -> bool :
@@ -178,7 +146,7 @@ def has_oos_samples(self) -> bool:
178146
179147 :return: True if there are out-of-scope samples.
180148 """
181- return Split .OOS in self .dataset
149+ return any ( split . startswith ( Split .OOS ) for split in self .dataset )
182150
183151 def dump (self ) -> dict [str , list [dict [str , Any ]]]:
184152 """
@@ -187,3 +155,60 @@ def dump(self) -> dict[str, list[dict[str, Any]]]:
187155 :return: Dataset dump.
188156 """
189157 return self .dataset .dump ()
158+
def _split(self, random_seed: int) -> None:
    """
    Re-partition ``self.dataset`` in place and validate the class count of every split.

    Steps, in order:

    1. If there is no ``Split.TEST`` split, ``Split.TRAIN`` is divided into train and test
       with ``test_size=0.2``.
    2. ``Split.TRAIN`` is divided into ``{Split.TRAIN}_0`` and ``{Split.TRAIN}_1`` with
       ``test_size=0.5``; the unsuffixed train split is then removed.
    3. Each ``{Split.TRAIN}_{idx}`` is divided into itself and ``{Split.VALIDATION}_{idx}``
       with ``test_size=0.2``.
    4. If a ``Split.OOS`` split exists, it is divided into ``{Split.OOS}_0`` and a remainder
       (``test_size=0.2``); the remainder is divided again (``test_size=0.5``) into
       ``{Split.OOS}_1`` and ``{Split.OOS}_2``; the unsuffixed OOS split is then removed.
    5. Every split whose name does not start with ``Split.OOS`` must report exactly
       ``self.n_classes`` classes.

    :param random_seed: Seed forwarded to every split operation.
    :raises ValueError: If a non-OOS split has a number of classes different from
        ``self.n_classes``.
    """
    if Split.TEST not in self.dataset:
        self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
            self.dataset,
            split=Split.TRAIN,
            test_size=0.2,
            random_seed=random_seed,
        )

    self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
        self.dataset,
        split=Split.TRAIN,
        test_size=0.5,
        random_seed=random_seed,
    )
    self.dataset.pop(Split.TRAIN)

    for idx in range(2):
        self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
            self.dataset,
            split=f"{Split.TRAIN}_{idx}",
            test_size=0.2,
            random_seed=random_seed,
        )

    # Guard on the exact key that is read and popped below. has_oos_samples() also matches
    # suffixed OOS splits, so it could be true while Split.OOS itself is missing.
    if Split.OOS in self.dataset:
        # The two values are unpacked in the order the returned mapping yields them
        # (presumably train part first, then test part — TODO confirm against the dataset type).
        self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
            self.dataset[Split.OOS]
            .train_test_split(
                test_size=0.2,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        # Second pass: the part stored as OOS_1 above is itself divided into OOS_1 and OOS_2.
        self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
            self.dataset[f"{Split.OOS}_1"]
            .train_test_split(
                test_size=0.5,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        self.dataset.pop(Split.OOS)

    for split in self.dataset:
        # OOS splits are excluded from the class-count check.
        if split.startswith(Split.OOS):
            continue
        n_classes_split = self.dataset.get_n_classes(split)
        if n_classes_split != self.n_classes:
            message = (
                f"Number of classes in split '{split}' doesn't match initial number of classes "
                f"({n_classes_split} != {self.n_classes})"
            )
            raise ValueError(message)
0 commit comments