Skip to content

Commit 5f0c565

Browse files
committed
fix typing
1 parent 2b4535e commit 5f0c565

File tree

4 files changed

+29
-29
lines changed

4 files changed

+29
-29
lines changed

autointent/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@
9090

9191
PREDICTION_METRICS_MULTILABEL = PREDICTION_METRICS_MULTICLASS
9292

93-
REGEXP_METRICS = _funcs_to_dict(regex_partial_accuracy, regex_partial_precision)
93+
REGEX_METRICS = _funcs_to_dict(regex_partial_accuracy, regex_partial_precision)
9494

9595
METRIC_FN = DecisionMetricFn | RegexMetricFn | RetrievalMetricFn | ScoringMetricFn
9696

autointent/modules/abc/_base.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -202,7 +202,3 @@ def _get_task_specs(labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
202202
multilabel = isinstance(in_domain_label, list)
203203
n_classes = len(in_domain_label) if multilabel else len(set(labels).difference([None])) # type: ignore[arg-type]
204204
return n_classes, multilabel, contains_oos_samples
205-
206-
@abstractmethod
207-
def get_train_data(self, context: Context) -> Any: # noqa: ANN401
208-
...

autointent/modules/regex/_simple.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
11
"""Module for regular expressions based intent detection."""
22

33
import re
4-
from typing import Any, Literal, TypedDict
4+
from typing import Any, TypedDict
55

66
from autointent import Context
77
from autointent.context.data_handler._data_handler import RegexPatterns
88
from autointent.context.optimization_info import Artifact
99
from autointent.custom_types import LabelType
10-
from autointent.metrics import REGEXP_METRICS
10+
from autointent.metrics import REGEX_METRICS
1111
from autointent.modules.abc import RegexModule
1212
from autointent.schemas import Intent
1313

@@ -33,23 +33,19 @@ def from_context(cls, context: Context) -> "Regex":
3333
"""Initialize from context."""
3434
return cls()
3535

36-
def get_train_data(self, context: Context) -> list[Intent]:
37-
return context.data_handler.dataset.intents
38-
39-
def fit(self, intents: list[dict[str, Any]]) -> None:
36+
def fit(self, intents: list[Intent]) -> None:
4037
"""
4138
Fit the model.
4239
4340
:param intents: Intents to fit
4441
"""
45-
intents_parsed = [Intent(**dct) for dct in intents]
4642
self.regex_patterns = [
4743
RegexPatterns(
4844
id=intent.id,
4945
regex_full_match=intent.regex_full_match,
5046
regex_partial_match=intent.regex_partial_match,
5147
)
52-
for intent in intents_parsed
48+
for intent in intents
5349
]
5450
self._compile_regex_patterns()
5551

@@ -109,24 +105,32 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str
109105
matches["partial_matches"].extend(intent_matches["partial_matches"])
110106
return list(prediction), matches
111107

112-
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
108+
def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
109+
self.fit(context.data_handler.dataset.intents)
110+
111+
val_utterances = context.data_handler.validation_utterances(0)
112+
val_labels = context.data_handler.validation_labels(0)
113+
114+
pred_labels = self.predict(val_utterances)
115+
116+
chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics}
117+
return self.score_metrics_ho((val_labels, pred_labels), chosen_metrics)
118+
119+
def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
113120
"""
114-
Calculate metric on test set and return metric value.
121+
Evaluate the scorer on a test set and compute the specified metric.
115122
116-
:param context: Context to score
117-
:param split: Split to score on
123+
:param context: Context containing test set and other data.
124+
:param split: Target split
118125
:return: Computed metrics value for the test set or error code of metrics
119126
"""
120-
# TODO add parameter to a whole pipeline (or just to regex module):
121-
# whether or not to omit utterances on next stages if they were detected with regex module
122-
assets = {
123-
"test_matches": list(self.predict(context.data_handler.test_utterances())),
124-
}
125-
if assets["test_matches"] is None:
126-
msg = "no matches found"
127-
raise ValueError(msg)
128-
chosen_metrics = {name: fn for name, fn in REGEXP_METRICS.items() if name in metrics}
129-
return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), chosen_metrics)
127+
chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics}
128+
129+
metrics_calculated, _ = self.score_metrics_cv(
130+
chosen_metrics, context.data_handler.validation_iterator()
131+
)
132+
133+
return metrics_calculated
130134

131135
def clear_cache(self) -> None:
132136
"""Clear cache."""

autointent/nodes/info/_regex.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from typing import ClassVar
55

66
from autointent.custom_types import NodeType
7-
from autointent.metrics import REGEXP_METRICS
7+
from autointent.metrics import REGEX_METRICS
88
from autointent.metrics.regex import RegexMetricFn
99
from autointent.modules.abc import RegexModule
1010
from autointent.modules.regex import Regex
@@ -15,7 +15,7 @@
1515
class RegexNodeInfo(NodeInfo):
1616
"""Regex node info."""
1717

18-
metrics_available: ClassVar[Mapping[str, RegexMetricFn]] = REGEXP_METRICS
18+
metrics_available: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS
1919

2020
modules_available: ClassVar[Mapping[str, type[RegexModule]]] = {NodeType.regex: Regex}
2121

0 commit comments

Comments
 (0)