Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions autointent/modules/abc/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _validate_multilabel(self, data_is_multilabel: bool) -> None:
logger.error(msg)
raise WrongClassificationError(msg)

def _validate_oos(self, data_contains_oos: bool) -> None:
def _validate_oos(self, data_contains_oos: bool, raise_error: bool = True) -> None:
if data_contains_oos != self.supports_oos:
if self.supports_oos and not data_contains_oos:
msg = (
Expand All @@ -143,10 +143,18 @@ def _validate_oos(self, data_contains_oos: bool) -> None:
elif not self.supports_oos and data_contains_oos:
msg = (
f'"{self.name}" is NOT designed to handle OOS samples, but your data '
"contain it. So, using this method reduces the power of classification."
"contains it. So, using this method reduces the power of classification."
)
if raise_error:
logger.error(msg)
raise ValueError(msg)
logger.warning(msg)

def _validate_task(self, labels: ListOfGenericLabels) -> None:
    """
    Derive the task specification from training labels and validate it.

    The three values returned by ``_get_task_specs`` are stored on the
    instance as ``_n_classes``, ``_multilabel`` and ``_oos``. The multilabel
    and OOS flags are then passed to ``_validate_multilabel`` and
    ``_validate_oos`` (the latter with its default arguments).

    :param labels: training labels
    """
    task_specs = self._get_task_specs(labels)
    self._n_classes, self._multilabel, self._oos = task_specs

    # Run the multilabel check before the OOS check, as the validators
    # may log and raise on a mismatch.
    self._validate_multilabel(self._multilabel)
    self._validate_oos(self._oos)

@staticmethod
def _get_task_specs(labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
"""
Expand Down
21 changes: 6 additions & 15 deletions autointent/modules/abc/_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,18 @@ def get_assets(self) -> DecisionArtifact:
def clear_cache(self) -> None:
"""Clear cache."""

def _validate_inputs(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
"""
Sanity check if labels and scores are valid to be a training data for decision module.

:param scores: training scores
:param labels: training labels
:return: number of classes, indicator if it's a multi-label task,
indicator if data contains oos samples
"""
n_classes, multilabel, contains_oos_samples = super()._get_task_specs(labels)

if n_classes != scores.shape[1]:
def _validate_task(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> None:
self._n_classes, self._multilabel, self._oos = self._get_task_specs(labels)
self._validate_multilabel(self._multilabel)
self._validate_oos(self._oos, raise_error=False)
if self._n_classes != scores.shape[1]:
msg = (
"There is a mismatch between provided labels and scores. "
f"Labels contains {n_classes} classes, but scores contain "
f"Labels contains {self._n_classes} classes, but scores contain "
f"probabilities for {scores.shape[1]} classes."
)
raise ValueError(msg)

return n_classes, multilabel, contains_oos_samples


def get_decision_evaluation_data(
context: Context,
Expand Down
2 changes: 2 additions & 0 deletions autointent/modules/abc/_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ScoringModule(Module, ABC):
using a scoring metric.
"""

supports_oos = False

def score(
self,
context: Context,
Expand Down
4 changes: 1 addition & 3 deletions autointent/modules/decision/_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def fit(
"""
self.tags = tags

self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
self._validate_multilabel(multilabel)
self._validate_oos(contains_oos)
self._validate_task(scores, labels)

metrics_list = []
for r in self.search_space:
Expand Down
4 changes: 1 addition & 3 deletions autointent/modules/decision/_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def fit(
:param tags: Tags to fit
:raises WrongClassificationError: If the classification is wrong.
"""
self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
self._validate_multilabel(multilabel)
self._validate_oos(contains_oos)
self._validate_task(scores, labels)

def predict(self, scores: npt.NDArray[Any]) -> list[int]:
"""
Expand Down
5 changes: 1 addition & 4 deletions autointent/modules/decision/_jinoos.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ def fit(
:param labels: Labels to fit
:param tags: Tags to fit
"""
# TODO: use dev split instead of test split.
self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
self._validate_multilabel(multilabel)
self._validate_oos(contains_oos)
self._validate_task(scores, labels)

pred_classes, best_scores = _predict(scores)

Expand Down
4 changes: 1 addition & 3 deletions autointent/modules/decision/_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def fit(
:param tags: Tags to fit
"""
self.tags = tags
self._n_classes, self._multilabel, contains_oos = self._validate_inputs(scores, labels)
self._validate_multilabel(self._multilabel)
self._validate_oos(contains_oos)
self._validate_task(scores, labels)

if not isinstance(self.thresh, float):
if len(self.thresh) != self._n_classes:
Expand Down
4 changes: 1 addition & 3 deletions autointent/modules/decision/_tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ def fit(
:param tags: Tags to fit
"""
self.tags = tags
self._n_classes, self._multilabel, contains_oos = self._validate_inputs(scores, labels)
self._validate_multilabel(self._multilabel)
self._validate_oos(contains_oos)
self._validate_task(scores, labels)

thresh_optimizer = ThreshOptimizer(
n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials
Expand Down
9 changes: 3 additions & 6 deletions autointent/modules/scoring/_description/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class DescriptionScorer(ScoringModule):
_n_classes: int
_multilabel: bool
_description_vectors: NDArray[Any]
supports_multiclass = True
supports_multilabel = True

def __init__(
self,
Expand Down Expand Up @@ -105,12 +107,7 @@ def fit(
:param descriptions: List of intent descriptions.
:raises ValueError: If descriptions contain None values or embeddings mismatch utterances.
"""
if isinstance(labels[0], list):
self._n_classes = len(labels[0])
self._multilabel = True
else:
self._n_classes = len(set(labels))
self._multilabel = False
self._validate_task(labels)

if any(description is None for description in descriptions):
error_text = (
Expand Down
4 changes: 3 additions & 1 deletion autointent/modules/scoring/_dnnc/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class DNNCScorer(ScoringModule):
_n_classes: int
_vector_index: VectorIndex
_cross_encoder: Ranker
supports_multilabel = False
supports_multiclass = True

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -155,7 +157,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
:param labels: List of labels corresponding to the utterances.
:raises ValueError: If the vector index mismatches the provided utterances.
"""
self._n_classes = len(set(labels))
self._validate_task(labels)

self._vector_index = VectorIndex(
self.embedder_name,
Expand Down
5 changes: 2 additions & 3 deletions autointent/modules/scoring/_knn/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class KNNScorer(ScoringModule):
_n_classes: int
_multilabel: bool
supports_multilabel = True
supports_oos = False
supports_multiclass = True

def __init__(
self,
Expand Down Expand Up @@ -132,8 +132,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
:param labels: List of labels corresponding to the utterances.
:raises ValueError: If the vector index mismatches the provided utterances.
"""
self._n_classes, self._multilabel, contains_oos = self._get_task_specs(labels)
self._validate_oos(contains_oos)
self._validate_task(labels)

self._vector_index = VectorIndex(
self.embedder_name,
Expand Down
4 changes: 3 additions & 1 deletion autointent/modules/scoring/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class LinearScorer(ScoringModule):
_multilabel: bool
_clf: LogisticRegressionCV | MultiOutputClassifier
_embedder: Embedder
supports_multiclass = True
supports_multilabel = True

def __init__(
self,
Expand Down Expand Up @@ -125,7 +127,7 @@ def fit(
:param labels: List of labels corresponding to the utterances.
:raises ValueError: If the vector index mismatches the provided utterances.
"""
self._multilabel = isinstance(labels[0], list)
self._validate_task(labels)

embedder = Embedder(
device=self.embedder_device,
Expand Down
8 changes: 3 additions & 5 deletions autointent/modules/scoring/_mlknn/mlknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class MLKnnScorer(ScoringModule):
_cond_prob_false: NDArray[Any]
_features: NDArray[Any]
_labels: NDArray[Any]
supports_multiclass = False
supports_multilabel = True

def __init__(
self,
Expand Down Expand Up @@ -135,11 +137,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
:raises TypeError: If the labels are not multi-label.
:raises ValueError: If the vector index mismatches the provided utterances.
"""
if not isinstance(labels[0], list):
msg = "mlknn scorer support only multilabel input"
raise TypeError(msg)

self._n_classes = len(labels[0])
self._validate_task(labels)

self._vector_index = VectorIndex(
self.embedder_name,
Expand Down
8 changes: 5 additions & 3 deletions autointent/modules/scoring/_sklearn/sklearn_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import Self

from autointent import Context, Embedder
from autointent.custom_types import LabelType
from autointent.custom_types import ListOfLabels
from autointent.modules.abc import ScoringModule

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,6 +38,8 @@ class SklearnScorer(ScoringModule):
"""

name = "sklearn"
supports_multilabel = True
supports_multiclass = True

def __init__(
self,
Expand Down Expand Up @@ -101,7 +103,7 @@ def from_context(
def fit(
self,
utterances: list[str],
labels: list[LabelType],
labels: ListOfLabels,
) -> None:
"""
Train the chosen sklearn classifier.
Expand All @@ -110,7 +112,7 @@ def fit(
:param labels: List of labels corresponding to the utterances.
:raises ValueError: If the vector index mismatches the provided utterances.
"""
self._multilabel = isinstance(labels[0], list)
self._validate_task(labels)

embedder = Embedder(
device=self.embedder_device,
Expand Down
Loading