Skip to content

Commit d9807cc

Browse files
authored
Feat/module inputs validation (#105)
* add validation to all modules * fix typing
1 parent 1824ce3 commit d9807cc

File tree

14 files changed

+42
-52
lines changed

14 files changed

+42
-52
lines changed

autointent/modules/abc/_base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _validate_multilabel(self, data_is_multilabel: bool) -> None:
133133
logger.error(msg)
134134
raise WrongClassificationError(msg)
135135

136-
def _validate_oos(self, data_contains_oos: bool) -> None:
136+
def _validate_oos(self, data_contains_oos: bool, raise_error: bool = True) -> None:
137137
if data_contains_oos != self.supports_oos:
138138
if self.supports_oos and not data_contains_oos:
139139
msg = (
@@ -143,10 +143,18 @@ def _validate_oos(self, data_contains_oos: bool) -> None:
143143
elif not self.supports_oos and data_contains_oos:
144144
msg = (
145145
f'"{self.name}" is NOT designed to handle OOS samples, but your data '
146-
"contain it. So, using this method reduces the power of classification."
146+
"contains it. So, using this method reduces the power of classification."
147147
)
148+
if raise_error:
149+
logger.error(msg)
150+
raise ValueError(msg)
148151
logger.warning(msg)
149152

153+
def _validate_task(self, labels: ListOfGenericLabels) -> None:
154+
self._n_classes, self._multilabel, self._oos = self._get_task_specs(labels)
155+
self._validate_multilabel(self._multilabel)
156+
self._validate_oos(self._oos)
157+
150158
@staticmethod
151159
def _get_task_specs(labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
152160
"""

autointent/modules/abc/_decision.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,27 +63,18 @@ def get_assets(self) -> DecisionArtifact:
6363
def clear_cache(self) -> None:
6464
"""Clear cache."""
6565

66-
def _validate_inputs(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
67-
"""
68-
Sanity check if labels and scores are valid to be a training data for decision module.
69-
70-
:param scores: training scores
71-
:param labels: training labels
72-
:return: number of classes, indicator if it's a multi-label task,
73-
indicator if data contains oos samples
74-
"""
75-
n_classes, multilabel, contains_oos_samples = super()._get_task_specs(labels)
76-
77-
if n_classes != scores.shape[1]:
66+
def _validate_task(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> None:
67+
self._n_classes, self._multilabel, self._oos = self._get_task_specs(labels)
68+
self._validate_multilabel(self._multilabel)
69+
self._validate_oos(self._oos, raise_error=False)
70+
if self._n_classes != scores.shape[1]:
7871
msg = (
7972
"There is a mismatch between provided labels and scores. "
80-
f"Labels contains {n_classes} classes, but scores contain "
73+
f"Labels contains {self._n_classes} classes, but scores contain "
8174
f"probabilities for {scores.shape[1]} classes."
8275
)
8376
raise ValueError(msg)
8477

85-
return n_classes, multilabel, contains_oos_samples
86-
8778

8879
def get_decision_evaluation_data(
8980
context: Context,

autointent/modules/abc/_scoring.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class ScoringModule(Module, ABC):
1919
using a scoring metric.
2020
"""
2121

22+
supports_oos = False
23+
2224
def score(
2325
self,
2426
context: Context,

autointent/modules/decision/_adaptive.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ def fit(
9797
"""
9898
self.tags = tags
9999

100-
self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
101-
self._validate_multilabel(multilabel)
102-
self._validate_oos(contains_oos)
100+
self._validate_task(scores, labels)
103101

104102
metrics_list = []
105103
for r in self.search_space:

autointent/modules/decision/_argmax.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def fit(
7676
:param tags: Tags to fit
7777
:raises WrongClassificationError: If the classification is wrong.
7878
"""
79-
self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
80-
self._validate_multilabel(multilabel)
81-
self._validate_oos(contains_oos)
79+
self._validate_task(scores, labels)
8280

8381
def predict(self, scores: npt.NDArray[Any]) -> list[int]:
8482
"""

autointent/modules/decision/_jinoos.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,7 @@ def fit(
8989
:param labels: Labels to fit
9090
:param tags: Tags to fit
9191
"""
92-
# TODO: use dev split instead of test split.
93-
self._n_classes, multilabel, contains_oos = self._validate_inputs(scores, labels)
94-
self._validate_multilabel(multilabel)
95-
self._validate_oos(contains_oos)
92+
self._validate_task(scores, labels)
9693

9794
pred_classes, best_scores = _predict(scores)
9895

autointent/modules/decision/_threshold.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def fit(
113113
:param tags: Tags to fit
114114
"""
115115
self.tags = tags
116-
self._n_classes, self._multilabel, contains_oos = self._validate_inputs(scores, labels)
117-
self._validate_multilabel(self._multilabel)
118-
self._validate_oos(contains_oos)
116+
self._validate_task(scores, labels)
119117

120118
if not isinstance(self.thresh, float):
121119
if len(self.thresh) != self._n_classes:

autointent/modules/decision/_tunable.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ def fit(
119119
:param tags: Tags to fit
120120
"""
121121
self.tags = tags
122-
self._n_classes, self._multilabel, contains_oos = self._validate_inputs(scores, labels)
123-
self._validate_multilabel(self._multilabel)
124-
self._validate_oos(contains_oos)
122+
self._validate_task(scores, labels)
125123

126124
thresh_optimizer = ThreshOptimizer(
127125
n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials

autointent/modules/scoring/_description/description.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class DescriptionScorer(ScoringModule):
2929
_n_classes: int
3030
_multilabel: bool
3131
_description_vectors: NDArray[Any]
32+
supports_multiclass = True
33+
supports_multilabel = True
3234

3335
def __init__(
3436
self,
@@ -105,12 +107,7 @@ def fit(
105107
:param descriptions: List of intent descriptions.
106108
:raises ValueError: If descriptions contain None values or embeddings mismatch utterances.
107109
"""
108-
if isinstance(labels[0], list):
109-
self._n_classes = len(labels[0])
110-
self._multilabel = True
111-
else:
112-
self._n_classes = len(set(labels))
113-
self._multilabel = False
110+
self._validate_task(labels)
114111

115112
if any(description is None for description in descriptions):
116113
error_text = (

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class DNNCScorer(ScoringModule):
7070
_n_classes: int
7171
_vector_index: VectorIndex
7272
_cross_encoder: Ranker
73+
supports_multilabel = False
74+
supports_multiclass = True
7375

7476
def __init__( # noqa: PLR0913
7577
self,
@@ -155,7 +157,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
155157
:param labels: List of labels corresponding to the utterances.
156158
:raises ValueError: If the vector index mismatches the provided utterances.
157159
"""
158-
self._n_classes = len(set(labels))
160+
self._validate_task(labels)
159161

160162
self._vector_index = VectorIndex(
161163
self.embedder_name,

0 commit comments

Comments
 (0)