Skip to content

Commit e7a724a

Browse files
committed
remove oos utilities from everywhere
1 parent 0956f13 commit e7a724a

File tree

11 files changed

+31
-142
lines changed

11 files changed

+31
-142
lines changed

autointent/_dataset/_dataset.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77
from typing import Any, TypedDict
88

9-
from datasets import ClassLabel, Sequence, concatenate_datasets, get_dataset_config_names, load_dataset
9+
from datasets import ClassLabel, Sequence, get_dataset_config_names, load_dataset
1010
from datasets import Dataset as HFDataset
1111

1212
from autointent.custom_types import LabelType, Split
@@ -39,7 +39,7 @@ class Dataset(dict[str, HFDataset]):
3939

4040
def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401
4141
"""
42-
Initialize the dataset and configure OOS split if applicable.
42+
Initialize the dataset.
4343
4444
:param args: Positional arguments to initialize the dataset.
4545
:param intents: List of intents associated with the dataset.
@@ -54,10 +54,6 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #
5454
if self.multilabel:
5555
self._encode_labels()
5656

57-
oos_split = self._create_oos_split()
58-
if oos_split is not None:
59-
self[Split.OOS] = oos_split
60-
6157
@property
6258
def multilabel(self) -> bool:
6359
"""
@@ -144,7 +140,10 @@ def to_json(self, filepath: str | Path) -> None:
144140
145141
:param filepath: The path to the file where the JSON data will be saved.
146142
"""
147-
with Path(filepath).open("w") as file:
143+
path = Path(filepath)
144+
if not path.parent.exists():
145+
path.parent.mkdir(parents=True)
146+
with path.open("w") as file:
148147
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
149148

150149
def push_to_hub(self, repo_id: str, private: bool = False) -> None:
@@ -204,15 +203,6 @@ def _encode_labels(self) -> "Dataset":
204203
self._encoded_labels = True
205204
return self
206205

207-
def _is_oos(self, sample: Sample) -> bool:
208-
"""
209-
Check if a sample is out-of-scope.
210-
211-
:param sample: The sample to check.
212-
:return: True if the sample is out-of-scope, False otherwise.
213-
"""
214-
return sample["label"] is None
215-
216206
def _to_multilabel(self, sample: Sample) -> Sample:
217207
"""
218208
Convert a sample's label to multilabel format.
@@ -241,20 +231,6 @@ def _encode_label(self, sample: Sample) -> Sample:
241231
sample["label"] = one_hot_label
242232
return sample
243233

244-
def _create_oos_split(self) -> HFDataset | None:
245-
"""
246-
Create an out-of-scope (OOS) split from the dataset.
247-
248-
:return: The OOS split if created, None otherwise.
249-
"""
250-
oos_splits = [split.filter(self._is_oos) for split in self.values()]
251-
oos_splits = [oos_split for oos_split in oos_splits if oos_split.num_rows]
252-
if oos_splits:
253-
for split_name, split in self.items():
254-
self[split_name] = split.filter(lambda sample: not self._is_oos(sample))
255-
return concatenate_datasets(oos_splits)
256-
return None
257-
258234
def _cast_label_feature(self) -> None:
259235
"""Cast the label feature of the dataset to the appropriate type."""
260236
for split_name, split in self.items():

autointent/_dataset/_validation.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,9 @@ def validate_dataset(self) -> "DatasetReader":
6666
]
6767
splits = [split for split in splits if split]
6868

69-
n_classes = [self._get_n_classes(split) for split in splits]
70-
if len(set(n_classes)) != 1:
71-
message = (
72-
f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
73-
"Ensure all splits have the same number of classes."
74-
)
75-
raise ValueError(message)
76-
if not n_classes[0]:
77-
message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
78-
raise ValueError(message)
69+
n_classes = self._validate_classes(splits)
7970

80-
self._validate_intents(n_classes[0])
71+
self._validate_intents(n_classes)
8172

8273
for split in splits:
8374
self._validate_split(split)
@@ -100,6 +91,20 @@ def _get_n_classes(self, split: list[Sample]) -> int:
10091
classes.add(label)
10192
return len(classes)
10293

94+
def _validate_classes(self, splits: list[list[Sample]]) -> int:
95+
"""Validate that each split has all classes."""
96+
n_classes = [self._get_n_classes(split) for split in splits]
97+
if len(set(n_classes)) != 1:
98+
message = (
99+
f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
100+
"Ensure all splits have the same number of classes."
101+
)
102+
raise ValueError(message)
103+
if not n_classes[0]:
104+
message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
105+
raise ValueError(message)
106+
return n_classes[0]
107+
103108
def _validate_intents(self, n_classes: int) -> "DatasetReader":
104109
"""
105110
Validate the intents by checking their IDs for sequential order.
@@ -132,7 +137,8 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
132137
intent_ids = {intent.id for intent in self.intents}
133138
for sample in split:
134139
message = (
135-
f"Sample with label {sample.label} references a non-existent intent ID. " f"Valid IDs are {intent_ids}."
140+
f"Sample with label {sample.label} and utterance {sample.utterance[:10]}... "
141+
f"references a non-existent intent ID. Valid IDs are {intent_ids}."
136142
)
137143
if isinstance(sample.label, int) and sample.label not in intent_ids:
138144
raise ValueError(message)

autointent/context/data_handler/_data_handler.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -153,30 +153,6 @@ def test_labels(self, idx: int | None = None) -> list[LabelType]:
153153
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
154154
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
155155

156-
def oos_utterances(self, idx: int | None = None) -> list[str]:
157-
"""
158-
Retrieve out-of-scope (OOS) utterances from the dataset.
159-
160-
If the dataset contains out-of-scope samples, retrieves the utterances
161-
from the specified OOS split index (if provided) or the primary OOS split.
162-
Returns an empty list if no OOS samples are available in the dataset.
163-
164-
:param idx: Optional index for a specific OOS split.
165-
:return: List of out-of-scope utterances, or an empty list if unavailable.
166-
"""
167-
if self.has_oos_samples():
168-
split = f"{Split.OOS}_{idx}" if idx is not None else Split.OOS
169-
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
170-
return []
171-
172-
def has_oos_samples(self) -> bool:
173-
"""
174-
Check if there are out-of-scope samples.
175-
176-
:return: True if there are out-of-scope samples.
177-
"""
178-
return any(split.startswith(Split.OOS) for split in self.dataset)
179-
180156
def dump(self, filepath: str | Path) -> None:
181157
"""
182158
Save the dataset splits and intents to a JSON file.
@@ -205,12 +181,7 @@ def _split(self, random_seed: int, split_train: bool) -> None:
205181
elif Split.VALIDATION in self.dataset:
206182
self._split_validation(random_seed)
207183

208-
if self.has_oos_samples():
209-
self._split_oos(random_seed)
210-
211184
for split in self.dataset:
212-
if split.startswith(Split.OOS):
213-
continue
214185
n_classes_split = self.dataset.get_n_classes(split)
215186
if n_classes_split != self.n_classes:
216187
message = (
@@ -280,24 +251,3 @@ def _split_test(self, test_size: float, random_seed: int) -> None:
280251
)
281252
self.dataset.pop(f"{Split.TEST}_0")
282253
self.dataset.pop(f"{Split.TEST}_1")
283-
284-
def _split_oos(self, random_seed: int) -> None:
285-
self.dataset[f"{Split.OOS}_0"], self.dataset[f"{Split.OOS}_1"] = (
286-
self.dataset[Split.OOS]
287-
.train_test_split(
288-
test_size=0.2,
289-
shuffle=True,
290-
seed=random_seed,
291-
)
292-
.values()
293-
)
294-
self.dataset[f"{Split.OOS}_1"], self.dataset[f"{Split.OOS}_2"] = (
295-
self.dataset[f"{Split.OOS}_1"]
296-
.train_test_split(
297-
test_size=0.5,
298-
shuffle=True,
299-
seed=random_seed,
300-
)
301-
.values()
302-
)
303-
self.dataset.pop(Split.OOS)

autointent/context/optimization_info/_data_models.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ class ScorerArtifact(Artifact):
4242
train_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for train utterances")
4343
validation_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for validation utterances")
4444
test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
45-
oos_scores: dict[str, NDArray[np.float64]] | None = Field(
46-
None,
47-
description="Scorer outputs for out-of-scope utterances",
48-
)
4945

5046

5147
class DecisionArtifact(Artifact):

autointent/context/optimization_info/_optimization_info.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Any, Literal
8+
from typing import TYPE_CHECKING, Any
99

1010
import numpy as np
1111
from numpy.typing import NDArray
@@ -175,20 +175,6 @@ def get_best_test_scores(self) -> NDArray[np.float64] | None:
175175
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
176176
return best_scorer_artifact.test_scores
177177

178-
def get_best_oos_scores(self, split: Literal["train", "validation", "test"]) -> NDArray[np.float64] | None:
179-
"""
180-
Retrieve the out-of-scope scores from the best scorer node.
181-
182-
:param split: The data split for which to retrieve the OOS scores.
183-
Must be one of "train", "validation", or "test".
184-
:return: A numpy array containing OOS scores for the specified split,
185-
or `None` if no OOS scores are available.
186-
"""
187-
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
188-
if best_scorer_artifact.oos_scores is not None:
189-
return best_scorer_artifact.oos_scores[split]
190-
return best_scorer_artifact.oos_scores
191-
192178
def dump_evaluation_results(self) -> dict[str, Any]:
193179
"""
194180
Dump evaluation results for all nodes.

autointent/custom_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,10 @@ class Split:
5858
:cvar str TRAIN: Training split.
5959
:cvar str VALIDATION: Validation split.
6060
:cvar str TEST: Testing split.
61-
:cvar str OOS: Out-of-scope split.
6261
:cvar str INTENTS: Intents split.
6362
"""
6463

6564
TRAIN = "train"
6665
VALIDATION = "validation"
6766
TEST = "test"
68-
OOS = "oos"
6967
INTENTS = "intents"

autointent/modules/abc/_decision.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,4 @@ def get_decision_evaluation_data(
9292
message = f"No '{split}' scores found in the optimization info"
9393
raise ValueError(message)
9494

95-
oos_scores = context.optimization_info.get_best_oos_scores(split)
96-
return_scores = scores
97-
if oos_scores is not None:
98-
oos_labels = (
99-
[[0] * context.get_n_classes()] * len(oos_scores) if context.is_multilabel() else [-1] * len(oos_scores) # type: ignore[list-item]
100-
)
101-
labels = np.concatenate([labels, np.array(oos_labels)])
102-
return_scores = np.concatenate([scores, oos_scores])
103-
104-
return labels.tolist(), return_scores # type: ignore[return-value]
95+
return labels.tolist(), scores # type: ignore[return-value]

autointent/modules/abc/_scoring.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from autointent import Context
99
from autointent.context.optimization_info import ScorerArtifact
10-
from autointent.custom_types import Split
1110
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
1211
from autointent.modules.abc import Module
1312

@@ -44,14 +43,6 @@ def score(
4443

4544
scores = self.predict(utterances)
4645

47-
self._oos_scores = None
48-
if context.data_handler.has_oos_samples():
49-
self._oos_scores = {
50-
Split.TRAIN: self.predict(context.data_handler.oos_utterances(0)),
51-
Split.VALIDATION: self.predict(context.data_handler.oos_utterances(1)),
52-
Split.TEST: self.predict(context.data_handler.oos_utterances(2)),
53-
}
54-
5546
self._train_scores = self.predict(context.data_handler.train_utterances(1))
5647
self._validation_scores = self.predict(context.data_handler.validation_utterances(1))
5748
self._test_scores = self.predict(context.data_handler.test_utterances())
@@ -63,13 +54,12 @@ def get_assets(self) -> ScorerArtifact:
6354
"""
6455
Retrieve assets generated during scoring.
6556
66-
:return: ScorerArtifact containing test scores and out-of-scope (OOS) scores.
57+
:return: ScorerArtifact containing test, validation and test scores.
6758
"""
6859
return ScorerArtifact(
6960
train_scores=self._train_scores,
7061
validation_scores=self._validation_scores,
7162
test_scores=self._test_scores,
72-
oos_scores=self._oos_scores,
7363
)
7464

7565
@abstractmethod

autointent/modules/regexp/_regexp.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def score(
126126
# whether or not to omit utterances on next stages if they were detected with regexp module
127127
assets = {
128128
"test_matches": list(self.predict(context.data_handler.test_utterances())),
129-
"oos_matches": None
130-
if not context.data_handler.has_oos_samples()
131-
else self.predict(context.data_handler.oos_utterances(2)),
132129
}
133130
if assets["test_matches"] is None:
134131
msg = "no matches found"

tests/context/datahandler/test_data_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ def test_dataset_initialization(mapping):
151151
{"train": mock_split(), "validation": mock_split(), "validation_0": mock_split()},
152152
{"train": mock_split(), "validation": mock_split(), "validation_1": mock_split()},
153153
{"train": mock_split(), "validation": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
154-
{"train": mock_split(), "oos": mock_split()},
155154
],
156155
)
157156
def test_dataset_validation(mapping):

0 commit comments

Comments (0)