66from pathlib import Path
77from typing import Any , TypedDict
88
9- from datasets import ClassLabel , Sequence , concatenate_datasets , get_dataset_config_names , load_dataset
109from datasets import Dataset as HFDataset
10+ from datasets import Sequence , get_dataset_config_names , load_dataset
1111
12- from autointent .custom_types import LabelType , Split
12+ from autointent .custom_types import LabelWithOOS , Split
1313from autointent .schemas import Intent , Tag
1414
1515
1616class Sample (TypedDict ):
1717 """
1818 Typed dictionary representing a dataset sample.
1919
20- :param str utterance: The text of the utterance.
21- :param LabelType | None label: The label associated with the utterance, or None if out-of-scope.
20+ :param utterance: The text of the utterance.
21+ :param label: The label associated with the utterance, or None if out-of-scope.
2222 """
2323
2424 utterance : str
25- label : LabelType | None
25+ label : LabelWithOOS
2626
2727
2828class Dataset (dict [str , HFDataset ]):
@@ -39,7 +39,7 @@ class Dataset(dict[str, HFDataset]):
3939
4040 def __init__ (self , * args : Any , intents : list [Intent ], ** kwargs : Any ) -> None : # noqa: ANN401
4141 """
42- Initialize the dataset and configure OOS split if applicable .
42+ Initialize the dataset.
4343
4444 :param args: Positional arguments to initialize the dataset.
4545 :param intents: List of intents associated with the dataset.
@@ -49,15 +49,6 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #
4949
5050 self .intents = intents
5151
52- self ._encoded_labels = False
53-
54- if self .multilabel :
55- self ._encode_labels ()
56-
57- oos_split = self ._create_oos_split ()
58- if oos_split is not None :
59- self [Split .OOS ] = oos_split
60-
6152 @property
6253 def multilabel (self ) -> bool :
6354 """
@@ -125,7 +116,6 @@ def to_multilabel(self) -> "Dataset":
125116 """
126117 for split_name , split in self .items ():
127118 self [split_name ] = split .map (self ._to_multilabel )
128- self ._encode_labels ()
129119 return self
130120
131121 def to_dict (self ) -> dict [str , list [dict [str , Any ]]]:
@@ -184,38 +174,15 @@ def get_n_classes(self, split: str) -> int:
184174 """
185175 classes = set ()
186176 for label in self [split ][self .label_feature ]:
187- match ( label , self . _encoded_labels ) :
188- case ( int (), _ ):
177+ match label :
178+ case int ():
189179 classes .add (label )
190- case (list (), False ):
191- for label_ in label :
192- classes .add (label_ )
193- case (list (), True ):
180+ case list ():
194181 for idx , label_ in enumerate (label ):
195182 if label_ :
196183 classes .add (idx )
197184 return len (classes )
198185
199- def _encode_labels (self ) -> "Dataset" :
200- """
201- Encode dataset labels into one-hot or multilabel format.
202-
203- :return: Self, with labels encoded.
204- """
205- for split_name , split in self .items ():
206- self [split_name ] = split .map (self ._encode_label )
207- self ._encoded_labels = True
208- return self
209-
210- def _is_oos (self , sample : Sample ) -> bool :
211- """
212- Check if a sample is out-of-scope.
213-
214- :param sample: The sample to check.
215- :return: True if the sample is out-of-scope, False otherwise.
216- """
217- return sample ["label" ] is None
218-
219186 def _to_multilabel (self , sample : Sample ) -> Sample :
220187 """
221188 Convert a sample's label to multilabel format.
@@ -224,50 +191,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
224191 :return: Sample with label in multilabel format.
225192 """
226193 if isinstance (sample ["label" ], int ):
227- sample ["label" ] = [sample ["label" ]]
228- return sample
229-
230- def _encode_label (self , sample : Sample ) -> Sample :
231- """
232- Encode a sample's label as a one-hot vector.
233-
234- :param sample: The sample to encode.
235- :return: Sample with encoded label.
236- """
237- one_hot_label = [0 ] * self .n_classes
238- match sample ["label" ]:
239- case int ():
240- one_hot_label [sample ["label" ]] = 1
241- case list ():
242- for idx in sample ["label" ]:
243- one_hot_label [idx ] = 1
244- sample ["label" ] = one_hot_label
194+ ohe_vector = [0 ] * self .n_classes
195+ ohe_vector [sample ["label" ]] = 1
196+ sample ["label" ] = ohe_vector
245197 return sample
246-
247- def _create_oos_split (self ) -> HFDataset | None :
248- """
249- Create an out-of-scope (OOS) split from the dataset.
250-
251- :return: The OOS split if created, None otherwise.
252- """
253- oos_splits = [split .filter (self ._is_oos ) for split in self .values ()]
254- oos_splits = [oos_split for oos_split in oos_splits if oos_split .num_rows ]
255- if oos_splits :
256- for split_name , split in self .items ():
257- self [split_name ] = split .filter (lambda sample : not self ._is_oos (sample ))
258- return concatenate_datasets (oos_splits )
259- return None
260-
261- def _cast_label_feature (self ) -> None :
262- """Cast the label feature of the dataset to the appropriate type."""
263- for split_name , split in self .items ():
264- new_features = split .features .copy ()
265- if self .multilabel :
266- new_features [self .label_feature ] = Sequence (
267- ClassLabel (num_classes = self .n_classes ),
268- )
269- else :
270- new_features [self .label_feature ] = ClassLabel (
271- num_classes = self .n_classes ,
272- )
273- self [split_name ] = split .cast (new_features )
0 commit comments