Commit 23efe32

feat: added optional multiple metrics (#108)
* feat: added optional multiple metrics
* fix: change metric to decision_metric
* fix: fixed type
* fix: typing
* refactor: decision metric to target metric
* fix: target metric
* fix: fixed tests
* fix: test_logreg
* fix: fixed unit test
* fix: fixed test
* feat: update guides
* feat: update guides
1 parent 36214e5 commit 23efe32

File tree

22 files changed: +96 −93 lines changed

Lines changed: 4 additions & 4 deletions
@@ -1,14 +1,14 @@
 # TODO: make up a better and more versatile config
 - node_type: embedding
-  metric: retrieval_hit_rate
+  target_metric: retrieval_hit_rate
   search_space:
     - module_name: retrieval
       k: [10]
       embedder_name:
         - avsolatorio/GIST-small-Embedding-v0
         - infgrad/stella-base-en-v2
 - node_type: scoring
-  metric: scoring_roc_auc
+  target_metric: scoring_roc_auc
   search_space:
     - module_name: knn
       k: [1, 3, 5, 10]
@@ -20,8 +20,8 @@
         - cross-encoder/ms-marco-MiniLM-L-6-v2
       k: [1, 3, 5, 10]
 - node_type: decision
-  metric: decision_accuracy
+  target_metric: decision_accuracy
   search_space:
     - module_name: threshold
       thresh: [0.5]
-    - module_name: argmax
+    - module_name: argmax
Lines changed: 4 additions & 4 deletions
@@ -1,21 +1,21 @@
 # TODO: make up a better and more versatile config
 - node_type: embedding
-  metric: retrieval_hit_rate_intersecting
+  target_metric: retrieval_hit_rate_intersecting
   search_space:
     - module_name: retrieval
       k: [10]
       embedder_name:
         - deepvk/USER-bge-m3
 - node_type: scoring
-  metric: scoring_roc_auc
+  target_metric: scoring_roc_auc
   search_space:
     - module_name: knn
       k: [3]
       weights: ["uniform", "distance", "closest"]
     - module_name: linear
 - node_type: decision
-  metric: decision_accuracy
+  target_metric: decision_accuracy
   search_space:
     - module_name: threshold
       thresh: [0.5]
-    - module_name: adaptive
+    - module_name: adaptive

autointent/modules/abc/_base.py

Lines changed: 2 additions & 6 deletions
@@ -34,11 +34,7 @@ def fit(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
         """

     @abstractmethod
-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Calculate metric on test set and return metric value.

@@ -110,7 +106,7 @@ def get_embedder_name(self) -> str | None:
         return None

     @staticmethod
-    def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float | str]:
+    def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]:
         """
         Score metrics on the test set.
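The concrete modules below all implement this widened signature with the same pattern: filter a registry of metric callables down to the requested names, then evaluate every survivor on one (labels, predictions) pair. A minimal self-contained sketch of that pattern follows; the metric registry here is illustrative, not one of the library's actual *_METRICS dictionaries.

from typing import Any, Callable

# Illustrative registry; the real modules use dictionaries such as
# SCORING_METRICS_MULTICLASS or PREDICTION_METRICS_MULTICLASS.
METRICS: dict[str, Callable[..., float]] = {
    "accuracy": lambda y_true, y_pred: sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true),
    "error_rate": lambda y_true, y_pred: sum(t != p for t, p in zip(y_true, y_pred)) / len(y_true),
}


def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Callable[..., float]]) -> dict[str, float]:
    """Evaluate every requested metric on the same (labels, predictions) pair."""
    return {name: fn(*params) for name, fn in metrics_dict.items()}


labels, predictions = [0, 1, 1, 0], [0, 1, 0, 0]
requested = ["accuracy"]  # plays the role of the new `metrics` argument
chosen_metrics = {name: fn for name, fn in METRICS.items() if name in requested}
print(score_metrics((labels, predictions), chosen_metrics))  # {'accuracy': 0.75}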

autointent/modules/abc/_decision.py

Lines changed: 3 additions & 6 deletions
@@ -40,11 +40,7 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
         :param scores: Scores to predict
         """

-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Calculate metric on test set and return metric value.

@@ -54,7 +50,8 @@ def score(
         """
         labels, scores = get_decision_evaluation_data(context, split)
         self._decisions = self.predict(scores)
-        return self.score_metrics((labels, self._decisions), PREDICTION_METRICS_MULTICLASS)
+        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+        return self.score_metrics((labels, self._decisions), chosen_metrics)

     def get_assets(self) -> DecisionArtifact:
         """Return useful assets that represent intermediate data into context."""

autointent/modules/abc/_scoring.py

Lines changed: 3 additions & 6 deletions
@@ -21,11 +21,7 @@ class ScoringModule(Module, ABC):

     supports_oos = False

-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Evaluate the scorer on a test set and compute the specified metric.

@@ -50,7 +46,8 @@ def score(
         self._test_scores = self.predict(context.data_handler.test_utterances())

         metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
-        return self.score_metrics((labels, scores), metrics_dict)
+        chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
+        return self.score_metrics((labels, scores), chosen_metrics)

     def get_assets(self) -> ScorerArtifact:
         """

autointent/modules/embedding/_logreg.py

Lines changed: 3 additions & 6 deletions
@@ -129,11 +129,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:

         self._classifier.fit(embeddings, labels)

-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Evaluate the embedding model using a specified metric function.

@@ -153,7 +149,8 @@ def score(

         probas = self.predict(utterances)
         metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
-        return self.score_metrics((labels, probas), metrics_dict)
+        chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
+        return self.score_metrics((labels, probas), chosen_metrics)

     def get_assets(self) -> RetrieverArtifact:
         """

autointent/modules/embedding/_retrieval.py

Lines changed: 3 additions & 6 deletions
@@ -109,11 +109,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         )
         self._vector_index.add(utterances, labels)

-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Evaluate the embedding model using a specified metric function.

@@ -133,7 +129,8 @@ def score(
         predictions, _, _ = self._vector_index.query(utterances, self.k)

         metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS
-        return self.score_metrics((labels, predictions), metrics_dict)
+        chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
+        return self.score_metrics((labels, predictions), chosen_metrics)

     def get_assets(self) -> RetrieverArtifact:
         """

autointent/modules/regexp/_regexp.py

Lines changed: 3 additions & 6 deletions
@@ -108,11 +108,7 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str
             matches["partial_matches"].extend(intent_matches["partial_matches"])
         return list(prediction), matches

-    def score(
-        self,
-        context: Context,
-        split: Literal["validation", "test"],
-    ) -> dict[str, float | str]:
+    def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
         """
         Calculate metric on test set and return metric value.

@@ -128,7 +124,8 @@ def score(
         if assets["test_matches"] is None:
             msg = "no matches found"
             raise ValueError(msg)
-        return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), REGEXP_METRICS)
+        chosen_metrics = {name: fn for name, fn in REGEXP_METRICS.items() if name in metrics}
+        return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), chosen_metrics)

     def clear_cache(self) -> None:
         """Clear cache."""

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 17 additions & 10 deletions
@@ -19,7 +19,13 @@
 class NodeOptimizer:
     """Node optimizer class."""

-    def __init__(self, node_type: NodeType, search_space: list[dict[str, Any]], metric: str) -> None:
+    def __init__(
+        self,
+        node_type: NodeType,
+        search_space: list[dict[str, Any]],
+        target_metric: str,
+        metrics: list[str] | None = None,
+    ) -> None:
         """
         Initialize the node optimizer.

@@ -29,7 +35,12 @@ def __init__(self, node_type: NodeType, search_space: list[dict[str, Any]], metr
         """
         self.node_type = node_type
         self.node_info = NODES_INFO[node_type]
-        self.metric_name = metric
+        self.decision_metric_name = target_metric
+
+        self.metrics = metrics if metrics is not None else []
+        if self.decision_metric_name not in self.metrics:
+            self.metrics.append(self.decision_metric_name)
+
         self.modules_search_spaces = search_space  # TODO search space validation
         self._logger = logging.getLogger(__name__)  # TODO solve duplicate logging messages problem

@@ -61,14 +72,10 @@ def fit(self, context: Context) -> None:
                 self.module_fit(module, context)

                 self._logger.debug("scoring %s module...", module_name)
-                metrics = module.score(context, "validation")
-                metric_value = metrics[self.metric_name]
-
-                # some metrics can produce error. When main metric produces error raise it.
-                if isinstance(metric_value, str):
-                    raise Exception(metric_value)  # noqa: TRY004, TRY002
+                metrics_score = module.score(context, "validation", self.metrics)
+                metric_value = metrics_score[self.decision_metric_name]

-                context.callback_handler.log_metrics(metrics)
+                context.callback_handler.log_metrics(metrics_score)
                 context.callback_handler.end_module()

                 dump_dir = context.get_dump_dir()
@@ -84,7 +91,7 @@ def fit(self, context: Context) -> None:
                     module_name,
                     module_kwargs,
                     metric_value,
-                    self.metric_name,
+                    self.decision_metric_name,
                     module.get_assets(),  # retriever name / scores / predictions
                     module_dump_dir,
                     module=module if not context.is_ram_to_clear() else None,
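A hedged sketch of how the updated constructor might be called; the import path is inferred from the file shown in this diff, and the extra metric name is illustrative rather than taken from the library's registries.

# Assumed import path, inferred from the file location above; not verified.
from autointent.nodes._optimization._node_optimizer import NodeOptimizer

optimizer = NodeOptimizer(
    node_type="scoring",                        # assumes NodeType accepts or coerces this string
    search_space=[{"module_name": "knn", "k": [1, 3, 5, 10]}],
    target_metric="scoring_roc_auc",            # drives module comparison and selection
    metrics=["scoring_f1"],                     # illustrative extra metric, logged but not used for selection
)
# target_metric is appended to self.metrics when absent, so each module.score()
# call evaluates ["scoring_f1", "scoring_roc_auc"] in a single pass.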
Lines changed: 3 additions & 3 deletions
@@ -1,16 +1,16 @@
 - node_type: embedding
-  metric: retrieval_hit_rate
+  target_metric: retrieval_hit_rate
   search_space:
     - module_name: retrieval
       k: [10]
       embedder_name:
         - sentence-transformers/all-MiniLM-L6-v2
 - node_type: scoring
-  metric: scoring_roc_auc
+  target_metric: scoring_roc_auc
   search_space:
     - module_name: description
       temperature: [1.0, 0.5, 0.1, 0.05]
 - node_type: decision
-  metric: decision_accuracy
+  target_metric: decision_accuracy
   search_space:
     - module_name: argmax
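Against a config like the one above, each module now returns a plain name-to-value mapping; the optimizer logs the whole dict and ranks candidates by the target metric alone. The values and the second metric name below are invented for illustration.

# Invented values, for illustration only.
metrics_score = {"scoring_roc_auc": 0.93, "scoring_accuracy": 0.88}
target_metric = "scoring_roc_auc"

metric_value = metrics_score[target_metric]  # the single number used to compare modules
print(metric_value)  # 0.93; the full metrics_score dict goes to the callback handler for logging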
