Skip to content

Commit 3380de8

Browse files
authored
Add validation for scoring and prediction (#61)
1 parent 8a98e5a commit 3380de8

File tree

23 files changed

+441
-151
lines changed

23 files changed

+441
-151
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._data_handler import DataHandler
2-
from ._dataset import Dataset
2+
from ._dataset import Dataset, Split
33
from ._schemas import Intent, Sample, Tag
44

5-
__all__ = ["DataHandler", "Dataset", "Intent", "Sample", "Tag"]
5+
__all__ = ["DataHandler", "Dataset", "Intent", "Sample", "Split", "Tag"]

autointent/context/data_handler/_data_handler.py

Lines changed: 132 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,9 @@ def __init__(
4848
if self.dataset.multilabel:
4949
self.dataset = self.dataset.encode_labels()
5050

51-
if Split.TEST not in self.dataset:
52-
logger.info("Splitting dataset into train and test splits")
53-
self.dataset = split_dataset(self.dataset, random_seed=random_seed)
51+
self.n_classes = self.dataset.n_classes
5452

55-
for split in self.dataset:
56-
if split == Split.OOS:
57-
continue
58-
n_classes_split = self.dataset.get_n_classes(split)
59-
if n_classes_split != self.n_classes:
60-
message = (
61-
f"Number of classes in split '{split}' doesn't match initial number of classes "
62-
f"({n_classes_split} != {self.n_classes})"
63-
)
64-
raise ValueError(message)
53+
self._split(random_seed)
6554

6655
self.regexp_patterns = [
6756
RegexPatterns(
@@ -86,60 +75,104 @@ def multilabel(self) -> bool:
8675
"""
8776
return self.dataset.multilabel
8877

89-
@property
90-
def n_classes(self) -> int:
78+
def train_utterances(self, idx: int | None = None) -> list[str]:
79+
"""
80+
Retrieve training utterances from the dataset.
81+
82+
If a specific training split index is provided, retrieves utterances
83+
from the indexed training split. Otherwise, retrieves utterances from
84+
the primary training split.
85+
86+
:param idx: Optional index for a specific training split.
87+
:return: List of training utterances.
9188
"""
92-
Get the number of classes in the dataset.
89+
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
90+
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
9391

94-
:return: Number of classes.
92+
def train_labels(self, idx: int | None = None) -> list[LabelType]:
9593
"""
96-
return self.dataset.n_classes
94+
Retrieve training labels from the dataset.
9795
98-
@property
99-
def train_utterances(self) -> list[str]:
96+
If a specific training split index is provided, retrieves labels
97+
from the indexed training split. Otherwise, retrieves labels from
98+
the primary training split.
99+
100+
:param idx: Optional index for a specific training split.
101+
:return: List of training labels.
100102
"""
101-
Get the training utterances.
103+
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
104+
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
102105

103-
:return: List of training utterances.
106+
def validation_utterances(self, idx: int | None = None) -> list[str]:
104107
"""
105-
return cast(list[str], self.dataset[Split.TRAIN][self.dataset.utterance_feature])
108+
Retrieve validation utterances from the dataset.
106109
107-
@property
108-
def train_labels(self) -> list[LabelType]:
110+
If a specific validation split index is provided, retrieves utterances
111+
from the indexed validation split. Otherwise, retrieves utterances from
112+
the primary validation split.
113+
114+
:param idx: Optional index for a specific validation split.
115+
:return: List of validation utterances.
109116
"""
110-
Get the training labels.
117+
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
118+
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
111119

112-
:return: List of training labels.
120+
def validation_labels(self, idx: int | None = None) -> list[LabelType]:
113121
"""
114-
return cast(list[LabelType], self.dataset[Split.TRAIN][self.dataset.label_feature])
122+
Retrieve validation labels from the dataset.
115123
116-
@property
117-
def test_utterances(self) -> list[str]:
124+
If a specific validation split index is provided, retrieves labels
125+
from the indexed validation split. Otherwise, retrieves labels from
126+
the primary validation split.
127+
128+
:param idx: Optional index for a specific validation split.
129+
:return: List of validation labels.
130+
"""
131+
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
132+
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
133+
134+
def test_utterances(self, idx: int | None = None) -> list[str]:
118135
"""
119-
Get the test utterances.
136+
Retrieve test utterances from the dataset.
120137
138+
If a specific test split index is provided, retrieves utterances
139+
from the indexed test split. Otherwise, retrieves utterances from
140+
the primary test split.
141+
142+
:param idx: Optional index for a specific test split.
121143
:return: List of test utterances.
122144
"""
123-
return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature])
145+
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
146+
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
124147

125-
@property
126-
def test_labels(self) -> list[LabelType]:
148+
def test_labels(self, idx: int | None = None) -> list[LabelType]:
127149
"""
128-
Get the test labels.
150+
Retrieve test labels from the dataset.
129151
152+
If a specific test split index is provided, retrieves labels
153+
from the indexed test split. Otherwise, retrieves labels from
154+
the primary test split.
155+
156+
:param idx: Optional index for a specific test split.
130157
:return: List of test labels.
131158
"""
132-
return cast(list[LabelType], self.dataset[Split.TEST][self.dataset.label_feature])
159+
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
160+
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
133161

134-
@property
135-
def oos_utterances(self) -> list[str]:
162+
def oos_utterances(self, idx: int | None = None) -> list[str]:
136163
"""
137-
Get the out-of-scope utterances.
164+
Retrieve out-of-scope (OOS) utterances from the dataset.
138165
139-
:return: List of out-of-scope utterances if available, otherwise an empty list.
166+
If the dataset contains out-of-scope samples, retrieves the utterances
167+
from the specified OOS split index (if provided) or the primary OOS split.
168+
Returns an empty list if no OOS samples are available in the dataset.
169+
170+
:param idx: Optional index for a specific OOS split.
171+
:return: List of out-of-scope utterances, or an empty list if unavailable.
140172
"""
141173
if self.has_oos_samples():
142-
return cast(list[str], self.dataset[Split.OOS][self.dataset.utterance_feature])
174+
split = f"{Split.OOS}_{idx}" if idx is not None else Split.OOS
175+
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
143176
return []
144177

145178
def has_oos_samples(self) -> bool:
@@ -148,7 +181,7 @@ def has_oos_samples(self) -> bool:
148181
149182
:return: True if there are out-of-scope samples.
150183
"""
151-
return Split.OOS in self.dataset
184+
return any(split.startswith(Split.OOS) for split in self.dataset)
152185

153186
def dump(self) -> dict[str, list[dict[str, Any]]]:
154187
"""
@@ -157,3 +190,60 @@ def dump(self) -> dict[str, list[dict[str, Any]]]:
157190
:return: Dataset dump.
158191
"""
159192
return self.dataset.dump()
193+
194+
def _split(self, random_seed: int) -> None:
195+
if Split.TEST not in self.dataset:
196+
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
197+
self.dataset,
198+
split=Split.TRAIN,
199+
test_size=0.2,
200+
random_seed=random_seed,
201+
)
202+
203+
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
204+
self.dataset,
205+
split=Split.TRAIN,
206+
test_size=0.5,
207+
random_seed=random_seed,
208+
)
209+
self.dataset.pop(Split.TRAIN)
210+
211+
for idx in range(2):
212+
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
213+
self.dataset,
214+
split=f"{Split.TRAIN}_{idx}",
215+
test_size=0.2,
216+
random_seed=random_seed,
217+
)
218+
219+
if self.has_oos_samples():
220+
self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
221+
self.dataset[Split.OOS]
222+
.train_test_split(
223+
test_size=0.2,
224+
shuffle=True,
225+
seed=random_seed,
226+
)
227+
.values()
228+
)
229+
self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
230+
self.dataset[f"{Split.OOS}_1"]
231+
.train_test_split(
232+
test_size=0.5,
233+
shuffle=True,
234+
seed=random_seed,
235+
)
236+
.values()
237+
)
238+
self.dataset.pop(Split.OOS)
239+
240+
for split in self.dataset:
241+
if split.startswith(Split.OOS):
242+
continue
243+
n_classes_split = self.dataset.get_n_classes(split)
244+
if n_classes_split != self.n_classes:
245+
message = (
246+
f"Number of classes in split '{split}' doesn't match initial number of classes "
247+
f"({n_classes_split} != {self.n_classes})"
248+
)
249+
raise ValueError(message)

autointent/context/data_handler/_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def multilabel(self) -> bool:
8181
8282
:return: True if the dataset is multilabel, False otherwise.
8383
"""
84-
return isinstance(self[Split.TRAIN].features[self.label_feature], Sequence)
84+
split = Split.TRAIN if Split.TRAIN in self else f"{Split.TRAIN}_0"
85+
return isinstance(self[split].features[self.label_feature], Sequence)
8586

8687
@cached_property
8788
def n_classes(self) -> int:

autointent/context/data_handler/_stratification.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.model_selection import train_test_split
1313
from skmultilearn.model_selection import IterativeStratification
1414

15-
from ._dataset import Dataset, Split
15+
from ._dataset import Dataset
1616

1717

1818
class StratifiedSplitter:
@@ -44,7 +44,7 @@ def __init__(
4444
self.random_seed = random_seed
4545
self.shuffle = shuffle
4646

47-
def __call__(self, dataset: HFDataset, multilabel: bool) -> tuple[Dataset, Dataset]:
47+
def __call__(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]:
4848
"""
4949
Split the dataset into training and testing subsets.
5050
@@ -73,21 +73,32 @@ def _split_multilabel(self, dataset: HFDataset) -> Sequence[npt.NDArray[np.int_]
7373
return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature])))
7474

7575

76-
def split_dataset(dataset: Dataset, random_seed: int) -> Dataset:
76+
def split_dataset(
77+
dataset: Dataset,
78+
split: str,
79+
test_size: float,
80+
random_seed: int,
81+
) -> tuple[HFDataset, HFDataset]:
7782
"""
7883
Split a Dataset object into training and testing subsets.
7984
8085
This function uses the StratifiedSplitter to perform stratified splitting
8186
while preserving the distribution of labels.
8287
8388
:param dataset: The dataset to be split, which must include training data.
89+
:param split: The specific data split to be divided, e.g., "train" or
90+
another split within the dataset.
91+
:param test_size: Proportion of the dataset to include in the test split.
92+
Should be a float value between 0.0 and 1.0, where 0.0
93+
means no data will be assigned to the test set, and 1.0
94+
means all data will be assigned to the test set. For example,
95+
a value of 0.2 indicates 20% of the data will be used for testing.
8496
:param random_seed: Seed for random number generation to ensure reproducibility.
85-
:return: The input dataset with training and testing splits.
97+
:return: A tuple containing two subsets of the selected split.
8698
"""
8799
splitter = StratifiedSplitter(
88-
test_size=0.25,
100+
test_size=test_size,
89101
label_feature=dataset.label_feature,
90102
random_seed=random_seed,
91103
)
92-
dataset[Split.TRAIN], dataset[Split.TEST] = splitter(dataset[Split.TRAIN], dataset.multilabel)
93-
return dataset
104+
return splitter(dataset[split], dataset.multilabel)

autointent/context/data_handler/_validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class DatasetReader(BaseModel):
1717
"""
1818

1919
train: list[Sample]
20-
validation: list[Sample] = []
2120
test: list[Sample] = []
2221
intents: list[Intent] = []
2322

@@ -30,7 +29,7 @@ def validate_dataset(self) -> Self:
3029
:return: The validated DatasetReader instance.
3130
"""
3231
self._validate_intents()
33-
for split in [self.train, self.validation, self.test]:
32+
for split in [self.train, self.test]:
3433
self._validate_split(split)
3534
return self
3635

autointent/context/optimization_info/_data_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,13 @@ class ScorerArtifact(Artifact):
3939
"""
4040

4141
model_config = ConfigDict(arbitrary_types_allowed=True)
42+
train_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for train utterances")
43+
validation_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for validation utterances")
4244
test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
43-
oos_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for out-of-scope utterances")
45+
oos_scores: dict[str, NDArray[np.float64]] | None = Field(
46+
None,
47+
description="Scorer outputs for out-of-scope utterances",
48+
)
4449

4550

4651
class PredictorArtifact(Artifact):

autointent/context/optimization_info/_optimization_info.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
import numpy as np
1111
from numpy.typing import NDArray
@@ -147,6 +147,24 @@ def get_best_embedder(self) -> str:
147147
best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type=NodeType.retrieval) # type: ignore[assignment]
148148
return best_retriever_artifact.embedder_name
149149

150+
def get_best_train_scores(self) -> NDArray[np.float64] | None:
151+
"""
152+
Retrieve the train scores from the best scorer node.
153+
154+
:return: Train scores as a numpy array.
155+
"""
156+
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
157+
return best_scorer_artifact.train_scores
158+
159+
def get_best_validation_scores(self) -> NDArray[np.float64] | None:
160+
"""
161+
Retrieve the validation scores from the best scorer node.
162+
163+
:return: Validation scores as a numpy array.
164+
"""
165+
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
166+
return best_scorer_artifact.validation_scores
167+
150168
def get_best_test_scores(self) -> NDArray[np.float64] | None:
151169
"""
152170
Retrieve the test scores from the best scorer node.
@@ -156,13 +174,18 @@ def get_best_test_scores(self) -> NDArray[np.float64] | None:
156174
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
157175
return best_scorer_artifact.test_scores
158176

159-
def get_best_oos_scores(self) -> NDArray[np.float64] | None:
177+
def get_best_oos_scores(self, split: Literal["train", "validation", "test"]) -> NDArray[np.float64] | None:
160178
"""
161179
Retrieve the out-of-scope scores from the best scorer node.
162180
163-
:return: Out-of-scope scores as a numpy array.
181+
:param split: The data split for which to retrieve the OOS scores.
182+
Must be one of "train", "validation", or "test".
183+
:return: A numpy array containing OOS scores for the specified split,
184+
or `None` if no OOS scores are available.
164185
"""
165186
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
187+
if best_scorer_artifact.oos_scores is not None:
188+
return best_scorer_artifact.oos_scores[split]
166189
return best_scorer_artifact.oos_scores
167190

168191
def dump_evaluation_results(self) -> dict[str, dict[str, list[float]]]:

0 commit comments

Comments
 (0)