@@ -48,20 +48,9 @@ def __init__(
4848 if self .dataset .multilabel :
4949 self .dataset = self .dataset .encode_labels ()
5050
51- if Split .TEST not in self .dataset :
52- logger .info ("Splitting dataset into train and test splits" )
53- self .dataset = split_dataset (self .dataset , random_seed = random_seed )
51+ self .n_classes = self .dataset .n_classes
5452
55- for split in self .dataset :
56- if split == Split .OOS :
57- continue
58- n_classes_split = self .dataset .get_n_classes (split )
59- if n_classes_split != self .n_classes :
60- message = (
61- f"Number of classes in split '{ split } ' doesn't match initial number of classes "
62- f"({ n_classes_split } != { self .n_classes } )"
63- )
64- raise ValueError (message )
53+ self ._split (random_seed )
6554
6655 self .regexp_patterns = [
6756 RegexPatterns (
@@ -86,60 +75,104 @@ def multilabel(self) -> bool:
8675 """
8776 return self .dataset .multilabel
8877
89- @property
90- def n_classes (self ) -> int :
def train_utterances(self, idx: int | None = None) -> list[str]:
    """
    Return the utterances stored in a training split.

    The split read is ``f"{Split.TRAIN}_{idx}"`` when ``idx`` is given and
    ``Split.TRAIN`` otherwise.

    NOTE(review): ``_split`` pops the un-indexed ``Split.TRAIN`` entry after
    creating the indexed ones, so the ``idx=None`` lookup presumably fails
    once initialization has run -- confirm against callers.

    :param idx: Optional index of the training split to read.
    :return: List of training utterances.
    """
    if idx is None:
        split_name = Split.TRAIN
    else:
        split_name = f"{Split.TRAIN}_{idx}"
    rows = self.dataset[split_name]
    utterances = rows[self.dataset.utterance_feature]
    return cast(list[str], utterances)
9391
94- :return: Number of classes.
def train_labels(self, idx: int | None = None) -> list[LabelType]:
    """
    Return the labels stored in a training split.

    The split read is ``f"{Split.TRAIN}_{idx}"`` when ``idx`` is given and
    ``Split.TRAIN`` otherwise.

    NOTE(review): ``_split`` pops the un-indexed ``Split.TRAIN`` entry after
    creating the indexed ones, so the ``idx=None`` lookup presumably fails
    once initialization has run -- confirm against callers.

    :param idx: Optional index of the training split to read.
    :return: List of training labels.
    """
    split_name = Split.TRAIN
    if idx is not None:
        split_name = f"{Split.TRAIN}_{idx}"
    rows = self.dataset[split_name]
    labels = rows[self.dataset.label_feature]
    return cast(list[LabelType], labels)
102105
103- :return: List of training utterances.
def validation_utterances(self, idx: int | None = None) -> list[str]:
    """
    Return the utterances stored in a validation split.

    The split read is ``f"{Split.VALIDATION}_{idx}"`` when ``idx`` is given
    and ``Split.VALIDATION`` otherwise.

    NOTE(review): ``_split`` only creates the indexed validation splits
    (``_0`` and ``_1``); the un-indexed one is presumably available only if
    the loaded dataset already contains it -- confirm.

    :param idx: Optional index of the validation split to read.
    :return: List of validation utterances.
    """
    if idx is None:
        split_name = Split.VALIDATION
    else:
        split_name = f"{Split.VALIDATION}_{idx}"
    rows = self.dataset[split_name]
    utterances = rows[self.dataset.utterance_feature]
    return cast(list[str], utterances)
111119
112- :return: List of training labels.
def validation_labels(self, idx: int | None = None) -> list[LabelType]:
    """
    Return the labels stored in a validation split.

    The split read is ``f"{Split.VALIDATION}_{idx}"`` when ``idx`` is given
    and ``Split.VALIDATION`` otherwise.

    NOTE(review): ``_split`` only creates the indexed validation splits
    (``_0`` and ``_1``); the un-indexed one is presumably available only if
    the loaded dataset already contains it -- confirm.

    :param idx: Optional index of the validation split to read.
    :return: List of validation labels.
    """
    split_name = Split.VALIDATION
    if idx is not None:
        split_name = f"{Split.VALIDATION}_{idx}"
    rows = self.dataset[split_name]
    labels = rows[self.dataset.label_feature]
    return cast(list[LabelType], labels)
133+
def test_utterances(self, idx: int | None = None) -> list[str]:
    """
    Return the utterances stored in a test split.

    The split read is ``f"{Split.TEST}_{idx}"`` when ``idx`` is given and
    ``Split.TEST`` otherwise.

    NOTE(review): ``_split`` only ensures the un-indexed ``Split.TEST``
    exists; indexed test splits are presumably present only if the loaded
    dataset provides them -- confirm.

    :param idx: Optional index of the test split to read.
    :return: List of test utterances.
    """
    if idx is None:
        split_name = Split.TEST
    else:
        split_name = f"{Split.TEST}_{idx}"
    rows = self.dataset[split_name]
    utterances = rows[self.dataset.utterance_feature]
    return cast(list[str], utterances)
124147
125- @property
126- def test_labels (self ) -> list [LabelType ]:
def test_labels(self, idx: int | None = None) -> list[LabelType]:
    """
    Return the labels stored in a test split.

    The split read is ``f"{Split.TEST}_{idx}"`` when ``idx`` is given and
    ``Split.TEST`` otherwise.

    NOTE(review): ``_split`` only ensures the un-indexed ``Split.TEST``
    exists; indexed test splits are presumably present only if the loaded
    dataset provides them -- confirm.

    :param idx: Optional index of the test split to read.
    :return: List of test labels.
    """
    split_name = Split.TEST
    if idx is not None:
        split_name = f"{Split.TEST}_{idx}"
    rows = self.dataset[split_name]
    labels = rows[self.dataset.label_feature]
    return cast(list[LabelType], labels)
133161
134- @property
135- def oos_utterances (self ) -> list [str ]:
def oos_utterances(self, idx: int | None = None) -> list[str]:
    """
    Retrieve out-of-scope (OOS) utterances from the dataset.

    The split read is ``f"{Split.OOS}_{idx}"`` when ``idx`` is given and
    ``Split.OOS`` otherwise. When that split is not present in the dataset,
    an empty list is returned instead of failing on the lookup.

    :param idx: Optional index for a specific OOS split.
    :return: List of out-of-scope utterances, or an empty list if the
        requested split is unavailable.
    """
    split = f"{Split.OOS}_{idx}" if idx is not None else Split.OOS
    # Guard on the exact split being read. ``has_oos_samples`` is true as
    # soon as *any* OOS-prefixed split exists, which does not guarantee that
    # this one does (``_split`` replaces ``Split.OOS`` with indexed splits).
    if split not in self.dataset:
        return []
    return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
144177
def has_oos_samples(self) -> bool:
    """
    Tell whether the dataset holds any out-of-scope split.

    A split counts as out-of-scope when its name starts with ``Split.OOS``,
    so both the plain split and its indexed variants are recognised.

    :return: True if there are out-of-scope samples.
    """
    for split_name in self.dataset:
        if split_name.startswith(Split.OOS):
            return True
    return False
152185
def dump(self) -> dict[str, list[dict[str, Any]]]:
    """
    Dump the dataset by delegating to ``self.dataset.dump()``.

    :return: Dataset dump.
    """
    dumped = self.dataset.dump()
    return dumped
193+
def _split(self, random_seed: int) -> None:
    """
    Carve the dataset into the indexed splits and validate their classes.

    Steps, in order:

    1. If there is no ``Split.TEST`` split, ``split_dataset`` is called on
       ``Split.TRAIN`` with ``test_size=0.2`` and its two results are stored
       as ``Split.TRAIN`` and ``Split.TEST``.
    2. ``Split.TRAIN`` is split with ``test_size=0.5`` into
       ``f"{Split.TRAIN}_0"`` and ``f"{Split.TRAIN}_1"``, then removed.
    3. Each indexed train split is split again with ``test_size=0.2``; the
       results overwrite it and populate ``f"{Split.VALIDATION}_{idx}"``.
    4. If a ``Split.OOS`` split exists, it is split with ``test_size=0.2``
       into ``_0``/``_1``; ``_1`` is then split with ``test_size=0.5`` into
       ``_1``/``_2``, and the original ``Split.OOS`` is removed.
    5. Every split whose name does not start with ``Split.OOS`` must report
       ``self.n_classes`` classes.

    NOTE(review): the unpacking assumes ``split_dataset`` and
    ``train_test_split(...).values()`` each yield exactly two datasets,
    presumably in (train, test) order -- confirm.

    :param random_seed: Seed forwarded to every split call.
    :raises ValueError: If a non-OOS split's class count differs from
        ``self.n_classes``.
    """
    if Split.TEST not in self.dataset:
        self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
            self.dataset,
            split=Split.TRAIN,
            test_size=0.2,
            random_seed=random_seed,
        )

    self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
        self.dataset,
        split=Split.TRAIN,
        test_size=0.5,
        random_seed=random_seed,
    )
    self.dataset.pop(Split.TRAIN)

    for idx in range(2):
        self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
            self.dataset,
            split=f"{Split.TRAIN}_{idx}",
            test_size=0.2,
            random_seed=random_seed,
        )

    # Guard on the exact key indexed below. ``has_oos_samples`` matches any
    # OOS-prefixed name, so a dataset that only carries already-indexed OOS
    # splits would otherwise fail on the ``Split.OOS`` lookup.
    if Split.OOS in self.dataset:
        self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
            self.dataset[Split.OOS]
            .train_test_split(
                test_size=0.2,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
            self.dataset[f"{Split.OOS}_1"]
            .train_test_split(
                test_size=0.5,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        self.dataset.pop(Split.OOS)

    for split in self.dataset:
        # OOS splits are excluded from the class-count check.
        if split.startswith(Split.OOS):
            continue
        n_classes_split = self.dataset.get_n_classes(split)
        if n_classes_split != self.n_classes:
            message = (
                f"Number of classes in split '{split}' doesn't match initial number of classes "
                f"({n_classes_split} != {self.n_classes})"
            )
            raise ValueError(message)
0 commit comments