Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
e7a724a
remove oos utilities from everywhere
voorhs Jan 20, 2025
8c5522e
stage progress
voorhs Jan 21, 2025
a7276c0
DataHandler: add support of splits with oos data
voorhs Jan 21, 2025
d4b4285
fix codestyle
voorhs Jan 21, 2025
efa2d99
fix wrong type annotation
voorhs Jan 21, 2025
f3ee03c
try to add proper multilabel case handling to stratified splitter
voorhs Jan 21, 2025
f73d143
add checking if loaded data already one hot encoded
voorhs Jan 21, 2025
c3b01cc
stage progress on getting rid of handling anything except from ohe la…
voorhs Jan 21, 2025
7b23667
continue
voorhs Jan 21, 2025
7a7f8e3
continue
voorhs Jan 21, 2025
ec0a7e2
update stratifyer
voorhs Jan 21, 2025
6d653a0
minor bug fix
voorhs Jan 22, 2025
f972eb3
fix typing
voorhs Jan 22, 2025
edac994
update test data
voorhs Jan 22, 2025
74d0cdd
update data handler a little bit
voorhs Jan 22, 2025
c9b340f
update test_nli_transformer
voorhs Jan 22, 2025
e91f865
bug fix in test data
voorhs Jan 22, 2025
7bf9e89
add oos, multilabel and inputs validation to decision modules
voorhs Jan 22, 2025
ad16d9f
fix codestyle
voorhs Jan 22, 2025
33ee01d
minor bug fix
voorhs Jan 22, 2025
b8829a2
add oos handling to metrics
voorhs Jan 22, 2025
ce53bd6
bug fix and update callback test
voorhs Jan 22, 2025
bb6db2f
update data_handler test
voorhs Jan 22, 2025
29abf14
update test for stratification
voorhs Jan 22, 2025
729fdd2
update description generation utility and corresponding tests
voorhs Jan 22, 2025
0e3eeee
bug fix in test
voorhs Jan 22, 2025
8d1282f
add test for oos handling in metrics functions
voorhs Jan 22, 2025
317c9bc
fix oos handling in metrics
voorhs Jan 22, 2025
23842e8
forgot to commit it earlier
voorhs Jan 22, 2025
c6628ed
minor refactoring of knn
voorhs Jan 22, 2025
77e9224
fix and update tests for decision modules
voorhs Jan 22, 2025
bfc0993
add validation for supporting multi-class problem
voorhs Jan 22, 2025
0bbc544
update tests for scoring modules
voorhs Jan 22, 2025
ad6da70
update how data_handler reads intent descriptions
voorhs Jan 22, 2025
59fe010
fix adaptive decision and add test on loading and dumping
voorhs Jan 22, 2025
aa75311
fix decision roc_auc and how labels are restored during auto-configur…
voorhs Jan 22, 2025
20acd67
fix some metric
voorhs Jan 22, 2025
722eab0
minor bug fix and update test for inference
voorhs Jan 22, 2025
48d518f
fix codestyle
voorhs Jan 22, 2025
5d584e8
stage progress on type fixing
voorhs Jan 22, 2025
8c092fd
finish fixing typing
voorhs Jan 23, 2025
357d769
pull actual code
voorhs Jan 23, 2025
234c3e0
update test for sklearn
voorhs Jan 23, 2025
4276715
fix user guides
voorhs Jan 23, 2025
b79716a
fix advanced user guide on datasets
voorhs Jan 23, 2025
6d61c5a
move data-related tests to a separate directory
voorhs Jan 23, 2025
94e012f
add oos handling test
voorhs Jan 23, 2025
97c5cd7
fix codestyle
voorhs Jan 23, 2025
2c23415
update doctests for decision modules
voorhs Jan 23, 2025
b71d341
remove clinc script
voorhs Jan 23, 2025
1424900
move exceptions to a separate submodule
voorhs Jan 23, 2025
e90dd08
fix imports
voorhs Jan 23, 2025
72f3f88
fix codestyle
voorhs Jan 23, 2025
662f812
remove unnecessary comments
voorhs Jan 23, 2025
977a4cf
fix tests for threshold and tunable modules
voorhs Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 16 additions & 89 deletions autointent/_dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
from pathlib import Path
from typing import Any, TypedDict

from datasets import ClassLabel, Sequence, concatenate_datasets, get_dataset_config_names, load_dataset
from datasets import Dataset as HFDataset
from datasets import Sequence, get_dataset_config_names, load_dataset

from autointent.custom_types import LabelType, Split
from autointent.custom_types import LabelWithOOS, Split
from autointent.schemas import Intent, Tag


class Sample(TypedDict):
"""
Typed dictionary representing a dataset sample.

:param str utterance: The text of the utterance.
:param LabelType | None label: The label associated with the utterance, or None if out-of-scope.
:param utterance: The text of the utterance.
:param label: The label associated with the utterance, or None if out-of-scope.
"""

utterance: str
label: LabelType | None
label: LabelWithOOS


class Dataset(dict[str, HFDataset]):
Expand All @@ -39,7 +39,7 @@ class Dataset(dict[str, HFDataset]):

def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401
"""
Initialize the dataset and configure OOS split if applicable.
Initialize the dataset.

:param args: Positional arguments to initialize the dataset.
:param intents: List of intents associated with the dataset.
Expand All @@ -49,15 +49,6 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #

self.intents = intents

self._encoded_labels = False

if self.multilabel:
self._encode_labels()

oos_split = self._create_oos_split()
if oos_split is not None:
self[Split.OOS] = oos_split

@property
def multilabel(self) -> bool:
"""
Expand Down Expand Up @@ -125,7 +116,6 @@ def to_multilabel(self) -> "Dataset":
"""
for split_name, split in self.items():
self[split_name] = split.map(self._to_multilabel)
self._encode_labels()
return self

def to_dict(self) -> dict[str, list[dict[str, Any]]]:
Expand All @@ -144,7 +134,10 @@ def to_json(self, filepath: str | Path) -> None:

:param filepath: The path to the file where the JSON data will be saved.
"""
with Path(filepath).open("w") as file:
path = Path(filepath)
if not path.parent.exists():
path.parent.mkdir(parents=True)
with path.open("w") as file:
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)

def push_to_hub(self, repo_id: str, private: bool = False) -> None:
Expand Down Expand Up @@ -181,38 +174,15 @@ def get_n_classes(self, split: str) -> int:
"""
classes = set()
for label in self[split][self.label_feature]:
match (label, self._encoded_labels):
case (int(), _):
match label:
case int():
classes.add(label)
case (list(), False):
for label_ in label:
classes.add(label_)
case (list(), True):
case list():
for idx, label_ in enumerate(label):
if label_:
classes.add(idx)
return len(classes)

def _encode_labels(self) -> "Dataset":
"""
Encode dataset labels into one-hot or multilabel format.

:return: Self, with labels encoded.
"""
for split_name, split in self.items():
self[split_name] = split.map(self._encode_label)
self._encoded_labels = True
return self

def _is_oos(self, sample: Sample) -> bool:
"""
Check if a sample is out-of-scope.

:param sample: The sample to check.
:return: True if the sample is out-of-scope, False otherwise.
"""
return sample["label"] is None

def _to_multilabel(self, sample: Sample) -> Sample:
"""
Convert a sample's label to multilabel format.
Expand All @@ -221,50 +191,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
:return: Sample with label in multilabel format.
"""
if isinstance(sample["label"], int):
sample["label"] = [sample["label"]]
return sample

def _encode_label(self, sample: Sample) -> Sample:
"""
Encode a sample's label as a one-hot vector.

:param sample: The sample to encode.
:return: Sample with encoded label.
"""
one_hot_label = [0] * self.n_classes
match sample["label"]:
case int():
one_hot_label[sample["label"]] = 1
case list():
for idx in sample["label"]:
one_hot_label[idx] = 1
sample["label"] = one_hot_label
ohe_vector = [0] * self.n_classes
ohe_vector[sample["label"]] = 1
sample["label"] = ohe_vector
return sample

def _create_oos_split(self) -> HFDataset | None:
"""
Create an out-of-scope (OOS) split from the dataset.

:return: The OOS split if created, None otherwise.
"""
oos_splits = [split.filter(self._is_oos) for split in self.values()]
oos_splits = [oos_split for oos_split in oos_splits if oos_split.num_rows]
if oos_splits:
for split_name, split in self.items():
self[split_name] = split.filter(lambda sample: not self._is_oos(sample))
return concatenate_datasets(oos_splits)
return None

def _cast_label_feature(self) -> None:
"""Cast the label feature of the dataset to the appropriate type."""
for split_name, split in self.items():
new_features = split.features.copy()
if self.multilabel:
new_features[self.label_feature] = Sequence(
ClassLabel(num_classes=self.n_classes),
)
else:
new_features[self.label_feature] = ClassLabel(
num_classes=self.n_classes,
)
self[split_name] = split.cast(new_features)
30 changes: 18 additions & 12 deletions autointent/_dataset/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,9 @@ def validate_dataset(self) -> "DatasetReader":
]
splits = [split for split in splits if split]

n_classes = [self._get_n_classes(split) for split in splits]
if len(set(n_classes)) != 1:
message = (
f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
"Ensure all splits have the same number of classes."
)
raise ValueError(message)
if not n_classes[0]:
message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
raise ValueError(message)
n_classes = self._validate_classes(splits)

self._validate_intents(n_classes[0])
self._validate_intents(n_classes)

for split in splits:
self._validate_split(split)
Expand All @@ -100,6 +91,20 @@ def _get_n_classes(self, split: list[Sample]) -> int:
classes.add(label)
return len(classes)

def _validate_classes(self, splits: list[list[Sample]]) -> int:
    """
    Validate that every split exposes the same non-zero number of classes.

    Note: this checks only the *count* of distinct classes per split, not
    that the splits contain the same class identities.

    :param splits: Non-empty list of dataset splits to check.
    :return: The common number of classes shared by all splits.
    :raises ValueError: If the splits disagree on the class count, or the
        count is zero.
    """
    n_classes = [self._get_n_classes(split) for split in splits]
    if len(set(n_classes)) != 1:
        message = (
            f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
            "Ensure all splits have the same number of classes."
        )
        raise ValueError(message)
    if not n_classes[0]:
        # Merged the previously split adjacent string literals; runtime text unchanged.
        message = "Number of classes is zero or undefined. Ensure at least one class is present in the splits."
        raise ValueError(message)
    return n_classes[0]

def _validate_intents(self, n_classes: int) -> "DatasetReader":
"""
Validate the intents by checking their IDs for sequential order.
Expand Down Expand Up @@ -132,7 +137,8 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
intent_ids = {intent.id for intent in self.intents}
for sample in split:
message = (
f"Sample with label {sample.label} references a non-existent intent ID. " f"Valid IDs are {intent_ids}."
f"Sample with label {sample.label} and utterance {sample.utterance[:10]}... "
f"references a non-existent intent ID. Valid IDs are {intent_ids}."
)
if isinstance(sample.label, int) and sample.label not in intent_ids:
raise ValueError(message)
Expand Down
19 changes: 12 additions & 7 deletions autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import json
import logging
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt
import yaml

from autointent import Context, Dataset
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import NodeType
from autointent.custom_types import ListOfGenericLabels, NodeType
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
from autointent.nodes import InferenceNode, NodeOptimizer
from autointent.utils import load_default_search_space, load_search_space

from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput

if TYPE_CHECKING:
from autointent.modules.abc import DecisionModule, ScoringModule


class Pipeline:
"""Pipeline optimizer class."""
Expand Down Expand Up @@ -185,7 +187,7 @@ def load(cls, path: str | Path) -> "Pipeline":
inference_dict_config = yaml.safe_load(file)
return cls.from_dict_config(inference_dict_config["nodes_configs"])

def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
def predict(self, utterances: list[str]) -> ListOfGenericLabels:
"""
Predict the labels for the utterances.

Expand All @@ -196,8 +198,11 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
msg = "Pipeline in optimization mode cannot perform inference"
raise RuntimeError(msg)

scores = self.nodes[NodeType.scoring].module.predict(utterances) # type: ignore[union-attr]
return self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]

scores = scoring_module.predict(utterances)
return decision_module.predict(scores)

def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
"""
Expand All @@ -211,7 +216,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
raise RuntimeError(msg)

scores, scores_metadata = self.nodes[NodeType.scoring].module.predict_with_metadata(utterances) # type: ignore[union-attr]
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr,arg-type]
regexp_predictions, regexp_predictions_metadata = None, None
if NodeType.regexp in self.nodes:
regexp_predictions, regexp_predictions_metadata = self.nodes[NodeType.regexp].module.predict_with_metadata( # type: ignore[union-attr]
Expand Down
10 changes: 5 additions & 5 deletions autointent/_pipeline/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from pydantic import BaseModel

from autointent.custom_types import LabelType
from autointent.custom_types import LabelWithOOS, ListOfLabels, ListOfLabelsWithOOS


class InferencePipelineUtteranceOutput(BaseModel):
"""Output of the inference pipeline for a single utterance."""

utterance: str
prediction: LabelType
regexp_prediction: LabelType | None
prediction: LabelWithOOS
regexp_prediction: LabelWithOOS
regexp_prediction_metadata: Any
score: list[float]
score_metadata: Any
Expand All @@ -19,6 +19,6 @@ class InferencePipelineUtteranceOutput(BaseModel):
class InferencePipelineOutput(BaseModel):
"""Output of the inference pipeline."""

predictions: list[LabelType]
regexp_predictions: list[LabelType] | None = None
predictions: ListOfLabelsWithOOS
regexp_predictions: ListOfLabels | None = None
utterances: list[InferencePipelineUtteranceOutput] | None = None
6 changes: 3 additions & 3 deletions autointent/_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn.linear_model import LogisticRegressionCV
from torch import nn

from autointent.custom_types import LabelType
from autointent.custom_types import ListOfLabels

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -158,7 +158,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
self._activations_list.clear()
return res # type: ignore[no-any-return]

def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
def _fit(self, pairs: list[tuple[str, str]], labels: ListOfLabels) -> None:
"""
Train the logistic regression model on cross-encoder features.

Expand All @@ -181,7 +181,7 @@ def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:

self._clf = clf

def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
"""
Construct training samples and train the logistic regression classifier.

Expand Down
Loading
Loading