
Commit 1824ce3

Refactor/oos handling (#102)
* remove oos utilities from everywhere
* stage progress
* DataHandler: add support of splits with oos data
* fix codestyle
* fix wrong type annotation
* try to add proper multilabel case handling to stratified splitter
* add checking if loaded data is already one-hot encoded
* stage progress on getting rid of handling anything except OHE labels for the multi-label case
* continue
* continue
* update stratifyer
* minor bug fix
* fix typing
* update test data
* update data handler a little bit
* update test_nli_transformer
* bug fix in test data
* add oos, multilabel and inputs validation to decision modules
* fix codestyle
* minor bug fix
* add oos handling to metrics
* bug fix and update callback test
* update data_handler test
* update test for stratification
* update description generation utility and corresponding tests
* bug fix in test
* add test for oos handling in metrics functions
* fix oos handling in metrics
* forgot to commit it earlier
* minor refactoring of knn
* fix and update tests for decision modules
* add validation for supporting multi-class problem
* update tests for scoring modules
* update how data_handler reads intent descriptions
* fix adaptive decision and add test on loading and dumping
* fix decision roc_auc and how labels are restored during auto-configuration
* fix some metric
* minor bug fix and update test for inference
* fix codestyle
* stage progress on type fixing
* finish fixing typing
* update test for sklearn
* fix user guides
* fix advanced user guide on datasets
* move data-related tests to a separate directory
* add oos handling test
* fix codestyle
* update doctests for decision modules
* remove clinc script
* move exceptions to a separate submodule
* fix imports
* fix codestyle
* remove unnecessary comments
* fix tests for threshold and tunable modules
1 parent: 3d18626 · commit: 1824ce3
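
At a glance, the refactor changes how out-of-scope (OOS) data flows through the library: OOS utterances are no longer pulled into a dedicated Split.OOS when a Dataset is built; they stay in their original splits with label None, and decision modules and metrics are taught to handle that value. A tiny illustrative sketch of the resulting sample representation (the dict shape follows the Sample TypedDict in the diff below; the utterances are made up):

# In-scope samples carry an intent id (or a one-hot list in the multilabel case);
# out-of-scope samples keep label=None and remain in their split.
samples = [
    {"utterance": "turn on the lights", "label": 1},
    {"utterance": "what is the meaning of life", "label": None},  # out-of-scope
]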

Some content is hidden: large commits do not display the full diff by default.

63 files changed (+1260 −1092 lines changed)

autointent/_dataset/_dataset.py

Lines changed: 16 additions & 89 deletions
@@ -6,23 +6,23 @@
 from pathlib import Path
 from typing import Any, TypedDict

-from datasets import ClassLabel, Sequence, concatenate_datasets, get_dataset_config_names, load_dataset
 from datasets import Dataset as HFDataset
+from datasets import Sequence, get_dataset_config_names, load_dataset

-from autointent.custom_types import LabelType, Split
+from autointent.custom_types import LabelWithOOS, Split
 from autointent.schemas import Intent, Tag


 class Sample(TypedDict):
     """
     Typed dictionary representing a dataset sample.

-    :param str utterance: The text of the utterance.
-    :param LabelType | None label: The label associated with the utterance, or None if out-of-scope.
+    :param utterance: The text of the utterance.
+    :param label: The label associated with the utterance, or None if out-of-scope.
     """

     utterance: str
-    label: LabelType | None
+    label: LabelWithOOS


 class Dataset(dict[str, HFDataset]):
@@ -39,7 +39,7 @@ class Dataset(dict[str, HFDataset]):

     def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None:  # noqa: ANN401
         """
-        Initialize the dataset and configure OOS split if applicable.
+        Initialize the dataset.

         :param args: Positional arguments to initialize the dataset.
         :param intents: List of intents associated with the dataset.
@@ -49,15 +49,6 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #

         self.intents = intents

-        self._encoded_labels = False
-
-        if self.multilabel:
-            self._encode_labels()
-
-        oos_split = self._create_oos_split()
-        if oos_split is not None:
-            self[Split.OOS] = oos_split
-
     @property
     def multilabel(self) -> bool:
         """
@@ -125,7 +116,6 @@ def to_multilabel(self) -> "Dataset":
         """
         for split_name, split in self.items():
             self[split_name] = split.map(self._to_multilabel)
-        self._encode_labels()
         return self

     def to_dict(self) -> dict[str, list[dict[str, Any]]]:
@@ -144,7 +134,10 @@ def to_json(self, filepath: str | Path) -> None:

         :param filepath: The path to the file where the JSON data will be saved.
         """
-        with Path(filepath).open("w") as file:
+        path = Path(filepath)
+        if not path.parent.exists():
+            path.parent.mkdir(parents=True)
+        with path.open("w") as file:
             json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)

     def push_to_hub(self, repo_id: str, private: bool = False) -> None:
@@ -181,38 +174,15 @@ def get_n_classes(self, split: str) -> int:
         """
         classes = set()
         for label in self[split][self.label_feature]:
-            match (label, self._encoded_labels):
-                case (int(), _):
+            match label:
+                case int():
                     classes.add(label)
-                case (list(), False):
-                    for label_ in label:
-                        classes.add(label_)
-                case (list(), True):
+                case list():
                     for idx, label_ in enumerate(label):
                         if label_:
                             classes.add(idx)
         return len(classes)

-    def _encode_labels(self) -> "Dataset":
-        """
-        Encode dataset labels into one-hot or multilabel format.
-
-        :return: Self, with labels encoded.
-        """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._encode_label)
-        self._encoded_labels = True
-        return self
-
-    def _is_oos(self, sample: Sample) -> bool:
-        """
-        Check if a sample is out-of-scope.
-
-        :param sample: The sample to check.
-        :return: True if the sample is out-of-scope, False otherwise.
-        """
-        return sample["label"] is None
-
     def _to_multilabel(self, sample: Sample) -> Sample:
         """
         Convert a sample's label to multilabel format.
@@ -221,50 +191,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
         :return: Sample with label in multilabel format.
         """
         if isinstance(sample["label"], int):
-            sample["label"] = [sample["label"]]
-        return sample
-
-    def _encode_label(self, sample: Sample) -> Sample:
-        """
-        Encode a sample's label as a one-hot vector.
-
-        :param sample: The sample to encode.
-        :return: Sample with encoded label.
-        """
-        one_hot_label = [0] * self.n_classes
-        match sample["label"]:
-            case int():
-                one_hot_label[sample["label"]] = 1
-            case list():
-                for idx in sample["label"]:
-                    one_hot_label[idx] = 1
-        sample["label"] = one_hot_label
+            ohe_vector = [0] * self.n_classes
+            ohe_vector[sample["label"]] = 1
+            sample["label"] = ohe_vector
         return sample
-
-    def _create_oos_split(self) -> HFDataset | None:
-        """
-        Create an out-of-scope (OOS) split from the dataset.
-
-        :return: The OOS split if created, None otherwise.
-        """
-        oos_splits = [split.filter(self._is_oos) for split in self.values()]
-        oos_splits = [oos_split for oos_split in oos_splits if oos_split.num_rows]
-        if oos_splits:
-            for split_name, split in self.items():
-                self[split_name] = split.filter(lambda sample: not self._is_oos(sample))
-            return concatenate_datasets(oos_splits)
-        return None
-
-    def _cast_label_feature(self) -> None:
-        """Cast the label feature of the dataset to the appropriate type."""
-        for split_name, split in self.items():
-            new_features = split.features.copy()
-            if self.multilabel:
-                new_features[self.label_feature] = Sequence(
-                    ClassLabel(num_classes=self.n_classes),
-                )
-            else:
-                new_features[self.label_feature] = ClassLabel(
-                    num_classes=self.n_classes,
-                )
-            self[split_name] = split.cast(new_features)
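
With OOS split creation and the separate label-encoding pass removed, the one-hot conversion now happens directly in _to_multilabel, and samples whose label is None (out-of-scope) or already a list pass through unchanged. A minimal standalone sketch of that per-sample conversion (illustrative only; in the code it is a Dataset method that uses self.n_classes):

# Illustrative re-implementation of the per-sample logic shown in the diff above.
def to_multilabel(sample: dict, n_classes: int) -> dict:
    """Turn an integer class label into a one-hot vector; leave None (OOS) and list labels as-is."""
    if isinstance(sample["label"], int):
        ohe_vector = [0] * n_classes
        ohe_vector[sample["label"]] = 1
        sample["label"] = ohe_vector
    return sample

print(to_multilabel({"utterance": "hi", "label": 2}, n_classes=4))      # label -> [0, 0, 1, 0]
print(to_multilabel({"utterance": "???", "label": None}, n_classes=4))  # OOS label stays None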

autointent/_dataset/_validation.py

Lines changed: 18 additions & 12 deletions
@@ -66,18 +66,9 @@ def validate_dataset(self) -> "DatasetReader":
         ]
         splits = [split for split in splits if split]

-        n_classes = [self._get_n_classes(split) for split in splits]
-        if len(set(n_classes)) != 1:
-            message = (
-                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
-                "Ensure all splits have the same number of classes."
-            )
-            raise ValueError(message)
-        if not n_classes[0]:
-            message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
-            raise ValueError(message)
+        n_classes = self._validate_classes(splits)

-        self._validate_intents(n_classes[0])
+        self._validate_intents(n_classes)

         for split in splits:
             self._validate_split(split)
@@ -100,6 +91,20 @@ def _get_n_classes(self, split: list[Sample]) -> int:
                 classes.add(label)
         return len(classes)

+    def _validate_classes(self, splits: list[list[Sample]]) -> int:
+        """Validate that each split has all classes."""
+        n_classes = [self._get_n_classes(split) for split in splits]
+        if len(set(n_classes)) != 1:
+            message = (
+                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
+                "Ensure all splits have the same number of classes."
+            )
+            raise ValueError(message)
+        if not n_classes[0]:
+            message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
+            raise ValueError(message)
+        return n_classes[0]
+
     def _validate_intents(self, n_classes: int) -> "DatasetReader":
         """
         Validate the intents by checking their IDs for sequential order.
@@ -132,7 +137,8 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
         intent_ids = {intent.id for intent in self.intents}
         for sample in split:
             message = (
-                f"Sample with label {sample.label} references a non-existent intent ID. " f"Valid IDs are {intent_ids}."
+                f"Sample with label {sample.label} and utterance {sample.utterance[:10]}... "
+                f"references a non-existent intent ID. Valid IDs are {intent_ids}."
             )
             if isinstance(sample.label, int) and sample.label not in intent_ids:
                 raise ValueError(message)
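
The class-count checks that used to be inlined in validate_dataset now live in a single helper that returns the shared class count. A rough standalone sketch of that logic over plain integer labels (the real _validate_classes works on Sample objects and also covers multilabel lists):

# Simplified stand-in for DatasetReader._validate_classes (assumes int labels only).
def validate_classes(splits: list[list[int]]) -> int:
    """Return the common class count, or raise if splits disagree or no class is present."""
    n_classes = [len(set(split)) for split in splits]
    if len(set(n_classes)) != 1:
        msg = f"Mismatch in number of classes across splits. Found class counts: {n_classes}."
        raise ValueError(msg)
    if not n_classes[0]:
        msg = "Number of classes is zero or undefined."
        raise ValueError(msg)
    return n_classes[0]

print(validate_classes([[0, 1, 2, 0], [2, 1, 0]]))  # 3
# validate_classes([[0, 1], [0, 1, 2]]) raises ValueError: class counts [2, 3] differ.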

autointent/_pipeline/_pipeline.py

Lines changed: 12 additions & 7 deletions
@@ -3,21 +3,23 @@
 import json
 import logging
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import numpy as np
-import numpy.typing as npt
 import yaml

 from autointent import Context, Dataset
 from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
-from autointent.custom_types import NodeType
+from autointent.custom_types import ListOfGenericLabels, NodeType
 from autointent.metrics import PREDICTION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.utils import load_default_search_space, load_search_space

 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput

+if TYPE_CHECKING:
+    from autointent.modules.abc import DecisionModule, ScoringModule
+

 class Pipeline:
     """Pipeline optimizer class."""
@@ -185,7 +187,7 @@ def load(cls, path: str | Path) -> "Pipeline":
             inference_dict_config = yaml.safe_load(file)
         return cls.from_dict_config(inference_dict_config["nodes_configs"])

-    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
+    def predict(self, utterances: list[str]) -> ListOfGenericLabels:
         """
         Predict the labels for the utterances.

@@ -196,8 +198,11 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Pipeline in optimization mode cannot perform inference"
             raise RuntimeError(msg)

-        scores = self.nodes[NodeType.scoring].module.predict(utterances)  # type: ignore[union-attr]
-        return self.nodes[NodeType.decision].module.predict(scores)  # type: ignore[union-attr]
+        scoring_module: ScoringModule = self.nodes[NodeType.scoring].module  # type: ignore[assignment,union-attr]
+        decision_module: DecisionModule = self.nodes[NodeType.decision].module  # type: ignore[assignment,union-attr]
+
+        scores = scoring_module.predict(utterances)
+        return decision_module.predict(scores)

     def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
         """
@@ -211,7 +216,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
             raise RuntimeError(msg)

         scores, scores_metadata = self.nodes[NodeType.scoring].module.predict_with_metadata(utterances)  # type: ignore[union-attr]
-        predictions = self.nodes[NodeType.decision].module.predict(scores)  # type: ignore[union-attr]
+        predictions = self.nodes[NodeType.decision].module.predict(scores)  # type: ignore[union-attr,arg-type]
         regexp_predictions, regexp_predictions_metadata = None, None
         if NodeType.regexp in self.nodes:
             regexp_predictions, regexp_predictions_metadata = self.nodes[NodeType.regexp].module.predict_with_metadata(  # type: ignore[union-attr]
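
Since predict now returns ListOfGenericLabels rather than a numpy array, out-of-scope decisions surface directly as None entries in the result. A hedged usage sketch (the import path, file name, and utterances are assumptions, not taken from the diff):

# Hypothetical inference call after this change.
from autointent import Pipeline  # assumption: Pipeline is exported from the package root

pipeline = Pipeline.load("runs/best_pipeline.yaml")  # hypothetical saved config
labels = pipeline.predict(["switch off the lights", "blargh qwerty"])
# `labels` is a list of per-utterance labels; None marks an utterance the
# decision module rejected as out-of-scope.
print(labels)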

autointent/_pipeline/_schemas.py

Lines changed: 5 additions & 5 deletions
@@ -2,15 +2,15 @@

 from pydantic import BaseModel

-from autointent.custom_types import LabelType
+from autointent.custom_types import LabelWithOOS, ListOfLabels, ListOfLabelsWithOOS


 class InferencePipelineUtteranceOutput(BaseModel):
     """Output of the inference pipeline for a single utterance."""

     utterance: str
-    prediction: LabelType
-    regexp_prediction: LabelType | None
+    prediction: LabelWithOOS
+    regexp_prediction: LabelWithOOS
     regexp_prediction_metadata: Any
     score: list[float]
     score_metadata: Any
@@ -19,6 +19,6 @@ class InferencePipelineUtteranceOutput(BaseModel):
 class InferencePipelineOutput(BaseModel):
     """Output of the inference pipeline."""

-    predictions: list[LabelType]
-    regexp_predictions: list[LabelType] | None = None
+    predictions: ListOfLabelsWithOOS
+    regexp_predictions: ListOfLabels | None = None
     utterances: list[InferencePipelineUtteranceOutput] | None = None
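
For readers following the type renames: the diff swaps LabelType for the new OOS-aware aliases throughout the schemas. A rough guess at their shape (not taken from this diff; the actual definitions live in autointent/custom_types.py):

# Assumed shapes of the aliases used above -- illustrative only.
LabelType = int | list[int]              # single intent id, or multilabel one-hot list
LabelWithOOS = LabelType | None          # None marks an out-of-scope prediction
ListOfLabels = list[LabelType]
ListOfLabelsWithOOS = list[LabelWithOOS]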

autointent/_ranker.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 from sklearn.linear_model import LogisticRegressionCV
 from torch import nn

-from autointent.custom_types import LabelType
+from autointent.custom_types import ListOfLabels

 logger = logging.getLogger(__name__)

@@ -158,7 +158,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
         self._activations_list.clear()
         return res  # type: ignore[no-any-return]

-    def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
+    def _fit(self, pairs: list[tuple[str, str]], labels: ListOfLabels) -> None:
         """
         Train the logistic regression model on cross-encoder features.

@@ -181,7 +181,7 @@ def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:

         self._clf = clf

-    def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
+    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """
         Construct training samples and train the logistic regression classifier.
