Skip to content

Commit 32e5e5f

Browse files
committed
Add splitting for OOS samples
1 parent 6cfc891 commit 32e5e5f

File tree

10 files changed

+105
-52
lines changed

10 files changed

+105
-52
lines changed

autointent/context/data_handler/_data_handler.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,40 +50,7 @@ def __init__(
5050

5151
self.n_classes = self.dataset.n_classes
5252

53-
if Split.TEST not in self.dataset:
54-
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
55-
self.dataset,
56-
split=Split.TRAIN,
57-
test_size=0.2,
58-
random_seed=random_seed,
59-
)
60-
61-
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
62-
self.dataset,
63-
split=Split.TRAIN,
64-
test_size=0.5,
65-
random_seed=random_seed,
66-
)
67-
self.dataset.pop(Split.TRAIN)
68-
69-
for idx in range(2):
70-
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
71-
self.dataset,
72-
split=f"{Split.TRAIN}_{idx}",
73-
test_size=0.2,
74-
random_seed=random_seed,
75-
)
76-
77-
for split in self.dataset:
78-
if split == Split.OOS:
79-
continue
80-
n_classes_split = self.dataset.get_n_classes(split)
81-
if n_classes_split != self.n_classes:
82-
message = (
83-
f"Number of classes in split '{split}' doesn't match initial number of classes "
84-
f"({n_classes_split} != {self.n_classes})"
85-
)
86-
raise ValueError(message)
53+
self._split(random_seed)
8754

8855
self.regexp_patterns = [
8956
RegexPatterns(
@@ -162,14 +129,15 @@ def test_labels(self, idx: int | None = None) -> list[LabelType]:
162129
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
163130
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
164131

165-
def oos_utterances(self) -> list[str]:
132+
def oos_utterances(self, idx: int | None = None) -> list[str]:
    """
    Get the out-of-scope utterances.

    :param idx: Index of the OOS sub-split to read (``oos_{idx}``); ``None``
        selects the unsplit ``oos`` split.
    :return: List of out-of-scope utterances if available, otherwise an empty list.
    """
    if not self.has_oos_samples():
        return []
    split = Split.OOS if idx is None else f"{Split.OOS}_{idx}"
    return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
174142

175143
def has_oos_samples(self) -> bool:
    """
    Check whether the dataset contains any out-of-scope split.

    :return: True if there are out-of-scope samples.
    """
    # OOS data may live under the plain "oos" split or under indexed
    # "oos_{idx}" splits, so match by prefix rather than exact name.
    for split_name in self.dataset:
        if split_name.startswith(Split.OOS):
            return True
    return False
182150

183151
def dump(self) -> dict[str, list[dict[str, Any]]]:
    """
    Serialize the dataset by delegating to the underlying dataset's own dump.

    :return: Dataset dump.
    """
    return self.dataset.dump()
158+
159+
def _split(self, random_seed: int) -> None:
    """
    Partition the dataset into the splits used during optimization.

    In-scope data ends up as ``train_0``/``train_1`` (with
    ``validation_0``/``validation_1`` carved out of each) plus a ``test``
    split. Out-of-scope data, when present, is divided into
    ``oos_0``/``oos_1``/``oos_2`` — roughly 80%/10%/10% of the OOS samples
    (NOTE(review): confirm these proportions are intended; they differ from
    the in-scope ratios).

    :param random_seed: Seed forwarded to every split for reproducibility.
    :raises ValueError: If any non-OOS split lost classes relative to the
        initial class count.
    """
    # Carve a test split out of train when the dataset does not ship one.
    if Split.TEST not in self.dataset:
        self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
            self.dataset,
            split=Split.TRAIN,
            test_size=0.2,
            random_seed=random_seed,
        )

    # Halve the training data into two independent folds: train_0 / train_1.
    self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
        self.dataset,
        split=Split.TRAIN,
        test_size=0.5,
        random_seed=random_seed,
    )
    # The original train split is fully consumed by the two folds.
    self.dataset.pop(Split.TRAIN)

    # Reserve 20% of each training fold as its validation counterpart.
    for idx in range(2):
        self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
            self.dataset,
            split=f"{Split.TRAIN}_{idx}",
            test_size=0.2,
            random_seed=random_seed,
        )

    if self.has_oos_samples():
        # First OOS split: oos_0 keeps 80%, the remaining 20% lands in oos_1.
        # NOTE(review): unpacking .values() relies on the split result
        # preserving (train, test) insertion order — confirm against the
        # datasets-library version in use.
        self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
            self.dataset[Split.OOS]
            .train_test_split(
                test_size=0.2,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        # Second OOS split: halve the 20% remainder into oos_1 / oos_2.
        self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
            self.dataset[f"{Split.OOS}_1"]
            .train_test_split(
                test_size=0.5,
                shuffle=True,
                seed=random_seed,
            )
            .values()
        )
        # The plain "oos" split is replaced by the indexed OOS splits.
        self.dataset.pop(Split.OOS)

    # Sanity check: every in-scope split must still cover all classes.
    for split in self.dataset:
        if split.startswith(Split.OOS):
            continue
        n_classes_split = self.dataset.get_n_classes(split)
        if n_classes_split != self.n_classes:
            message = (
                f"Number of classes in split '{split}' doesn't match initial number of classes "
                f"({n_classes_split} != {self.n_classes})"
            )
            raise ValueError(message)

autointent/context/optimization_info/_data_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ class ScorerArtifact(Artifact):
4242
train_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for train utterances")
4343
validation_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for validation utterances")
4444
test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
45-
oos_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for out-of-scope utterances")
45+
oos_scores: dict[str, NDArray[np.float64]] | None = Field(
46+
None,
47+
description="Scorer outputs for out-of-scope utterances",
48+
)
4649

4750

4851
class PredictorArtifact(Artifact):

autointent/context/optimization_info/_optimization_info.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
import numpy as np
1111
from numpy.typing import NDArray
@@ -174,13 +174,15 @@ def get_best_test_scores(self) -> NDArray[np.float64] | None:
174174
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
175175
return best_scorer_artifact.test_scores
176176

177-
def get_best_oos_scores(self) -> NDArray[np.float64] | None:
177+
def get_best_oos_scores(self, split: Literal["train", "validation", "test"]) -> NDArray[np.float64] | None:
    """
    Retrieve the out-of-scope scores from the best scorer node.

    :param split: Which split's OOS scores to fetch.
    :return: Out-of-scope scores as a numpy array, or None when the best
        scorer produced no OOS scores.
    """
    best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring)  # type: ignore[assignment]
    oos_scores = best_scorer_artifact.oos_scores
    if oos_scores is None:
        return None
    return oos_scores[split]
185187

186188
def dump_evaluation_results(self) -> dict[str, dict[str, list[float]]]:

autointent/modules/_regexp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def score(
130130
assets = {
131131
"test_matches": list(self.predict(context.data_handler.test_utterances())),
132132
"oos_matches": None
133-
if context.data_handler.has_oos_samples()
134-
else self.predict(context.data_handler.oos_utterances()),
133+
if not context.data_handler.has_oos_samples()
134+
else self.predict(context.data_handler.oos_utterances(2)),
135135
}
136136
if assets["test_matches"] is None:
137137
msg = "no matches found"

autointent/modules/prediction/_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,18 @@ def get_prediction_evaluation_data(
8282
elif split == "validation":
8383
labels = np.array(context.data_handler.validation_labels(1))
8484
scores = context.optimization_info.get_best_validation_scores()
85-
else:
85+
elif split == "test":
8686
labels = np.array(context.data_handler.test_labels())
8787
scores = context.optimization_info.get_best_test_scores()
88+
else:
89+
message = f"Invalid split '{split}' provided. Expected one of 'train', 'validation', or 'test'."
90+
raise ValueError(message)
8891

8992
if scores is None:
9093
message = f"No '{split}' scores found in the optimization info"
9194
raise ValueError(message)
9295

93-
oos_scores = context.optimization_info.get_best_oos_scores()
96+
oos_scores = context.optimization_info.get_best_oos_scores(split)
9497
return_scores = scores
9598
if oos_scores is not None:
9699
oos_labels = (

autointent/modules/retrieval/_vectordb.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,12 @@ def score(
127127
if split == "validation":
128128
utterances = context.data_handler.validation_utterances(0)
129129
labels = context.data_handler.validation_labels(0)
130-
else:
130+
elif split == "test":
131131
utterances = context.data_handler.test_utterances()
132132
labels = context.data_handler.test_labels()
133+
else:
134+
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
135+
raise ValueError(message)
133136
predictions, _, _ = self.vector_index.query(utterances, self.k)
134137
return metric_fn(labels, predictions)
135138

autointent/modules/scoring/_base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy.typing as npt
88

99
from autointent import Context
10+
from autointent.context.data_handler import Split
1011
from autointent.context.optimization_info import ScorerArtifact
1112
from autointent.metrics import ScoringMetricFn
1213
from autointent.modules import Module
@@ -36,15 +37,22 @@ def score(
3637
if split == "validation":
3738
utterances = context.data_handler.validation_utterances(0)
3839
labels = context.data_handler.validation_labels(0)
39-
else:
40+
elif split == "test":
4041
utterances = context.data_handler.test_utterances()
4142
labels = context.data_handler.test_labels()
43+
else:
44+
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
45+
raise ValueError(message)
4246

4347
scores = self.predict(utterances)
4448

4549
self._oos_scores = None
4650
if context.data_handler.has_oos_samples():
47-
self._oos_scores = self.predict(context.data_handler.oos_utterances())
51+
self._oos_scores = {
52+
Split.TRAIN: self.predict(context.data_handler.oos_utterances(0)),
53+
Split.VALIDATION: self.predict(context.data_handler.oos_utterances(1)),
54+
Split.TEST: self.predict(context.data_handler.oos_utterances(2)),
55+
}
4856

4957
self._train_scores = self.predict(context.data_handler.train_utterances(1))
5058
self._validation_scores = self.predict(context.data_handler.validation_utterances(1))

tests/assets/data/clinc_subset.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@
145145
},
146146
{
147147
"utterance": "what size wipers does this car take"
148+
},
149+
{
150+
"utterance": "where is the dipstick"
151+
},
152+
{
153+
"utterance": "how much is 1 share of aapl"
154+
},
155+
{
156+
"utterance": "how is glue made"
148157
}
149158
]
150159
}

tests/modules/prediction/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def multiclass_fit_data(dataset):
2020
scorer = KNNScorer(**knn_params)
2121

2222
scorer.fit(data_handler.train_utterances(1), data_handler.train_labels(1))
23-
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances())
24-
labels = data_handler.validation_labels(1) + [-1] * len(data_handler.oos_utterances())
23+
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances(1))
24+
labels = data_handler.validation_labels(1) + [-1] * len(data_handler.oos_utterances(1))
2525
return scores, labels
2626

2727

@@ -40,6 +40,6 @@ def multilabel_fit_data(dataset):
4040
scorer = KNNScorer(**knn_params)
4141

4242
scorer.fit(data_handler.train_utterances(1), data_handler.train_labels(1))
43-
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances())
44-
labels = data_handler.validation_labels(1) + [[0] * data_handler.n_classes] * len(data_handler.oos_utterances())
43+
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances(1))
44+
labels = data_handler.validation_labels(1) + [[0] * data_handler.n_classes] * len(data_handler.oos_utterances(1))
4545
return scores, labels

tests/modules/prediction/test_tunable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_multilabel(multilabel_fit_data):
2121
predictor.fit(*multilabel_fit_data)
2222
scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]])
2323
predictions = predictor.predict(scores)
24-
desired = np.array([[0, 1, 0], [1, 0, 1], [0, 0, 1]])
24+
desired = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 1]])
2525

2626
np.testing.assert_array_equal(predictions, desired)
2727

0 commit comments

Comments (0)