Commit 9f6613b

report multiple scores
1 parent 071b38f commit 9f6613b

16 files changed (+201, -54 lines)


autointent/_callbacks/base.py

Lines changed: 8 additions & 0 deletions

@@ -42,6 +42,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         :param kwargs: Data to log.
         """

+    @abstractmethod
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+
     @abstractmethod
     def end_module(self) -> None:
         """End a module."""

autointent/_callbacks/callback_handler.py

Lines changed: 8 additions & 0 deletions

@@ -44,6 +44,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.call_events("log_value", **kwargs)

+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.call_events("log_metrics", metrics=metrics)
+
     def end_module(self) -> None:
         """End a module."""
         self.call_events("end_module")

autointent/_callbacks/tensorboard.py

Lines changed: 16 additions & 0 deletions

@@ -73,6 +73,22 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
             else:
                 self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]

+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        if self.module_writer is None:
+            msg = "start_run must be called before log_value."
+            raise RuntimeError(msg)
+
+        for key, value in metrics.items():
+            if isinstance(value, int | float):
+                self.module_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
+            else:
+                self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.
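
The TensorBoard backend routes each entry by type: numbers go to add_scalar, everything else (including error strings from metrics that failed) to add_text. The same routing in a standalone sketch using torch's SummaryWriter (log directory and metric values are made up):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/log_metrics_demo")
metrics = {"scoring_roc_auc": 0.93, "decision_f1": "Only one class present in y_true."}

for key, value in metrics.items():
    if isinstance(value, int | float):
        writer.add_scalar(key, value)     # plotted as a scalar
    else:
        writer.add_text(key, str(value))  # shown on the Text tab

writer.close()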

autointent/_callbacks/wandb.py

Lines changed: 8 additions & 0 deletions

@@ -59,6 +59,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.wandb.log(kwargs)

+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.wandb.log(metrics)
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.
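
The W&B backend hands the whole dictionary to wandb.log unchanged, so each numeric entry becomes its own chart. A minimal usage sketch (the project name is made up; offline mode avoids needing an account):

import wandb

run = wandb.init(project="autointent-demo", mode="offline")
run.log({"scoring_roc_auc": 0.93, "scoring_hit_rate": 0.88})
run.finish()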

autointent/_embedder.py

Lines changed: 1 addition & 1 deletion

@@ -179,4 +179,4 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         embeddings_path.parent.mkdir(parents=True, exist_ok=True)
         np.save(embeddings_path, embeddings)

-        return embeddings  # type: ignore[return-value]
+        return embeddings

autointent/metrics/decision.py

Lines changed: 6 additions & 6 deletions

@@ -50,7 +50,7 @@ def decision_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> f
     :return: Score of the decision accuracy
     """
     y_true_, y_pred_ = transform(y_true, y_pred)
-    return np.mean(y_true_ == y_pred_)  # type: ignore[no-any-return]
+    return float(np.mean(y_true_ == y_pred_))


 def _decision_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -83,7 +83,7 @@ def _decision_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE
         binarized_pred = (y_pred_ == k).astype(int)
         roc_auc_scores.append(roc_auc_score(binarized_true, binarized_pred))

-    return np.mean(roc_auc_scores)  # type: ignore[return-value]
+    return float(np.mean(roc_auc_scores))


 def _decision_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -98,7 +98,7 @@ def _decision_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE
     :param y_pred: Predicted values of labels
     :return: Score of the decision accuracy
     """
-    return roc_auc_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(roc_auc_score(y_true, y_pred, average="macro"))


 def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -135,7 +135,7 @@ def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
     :param y_pred: Predicted values of labels
     :return: Score of the decision precision
     """
-    return precision_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(precision_score(y_true, y_pred, average="macro"))


 def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -150,7 +150,7 @@ def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> flo
     :param y_pred: Predicted values of labels
     :return: Score of the decision recall
     """
-    return recall_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(recall_score(y_true, y_pred, average="macro"))


 def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -165,4 +165,4 @@ def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
     :param y_pred: Predicted values of labels
     :return: Score of the decision accuracy
     """
-    return f1_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(f1_score(y_true, y_pred, average="macro"))
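
This file and the retrieval.py and scoring.py hunks below make the same mechanical change: NumPy and scikit-learn reducers return NumPy scalar types rather than plain Python floats, so these functions previously needed type-ignore pragmas to satisfy their declared -> float return type. Wrapping the result in float() removes the need for them. A quick illustration:

import numpy as np

raw = np.mean([1, 0, 1])
print(type(raw))         # <class 'numpy.float64'> -- a NumPy scalar
print(type(float(raw)))  # <class 'float'> -- the type the metric signatures promise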

autointent/metrics/retrieval.py

Lines changed: 1 addition & 1 deletion

@@ -630,7 +630,7 @@ def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE
         cur_idcg = _idcg(rel_scores, k)
         ndcg_scores.append(0.0 if cur_idcg == 0 else cur_dcg / cur_idcg)

-    return np.mean(ndcg_scores)  # type: ignore[return-value]
+    return float(np.mean(ndcg_scores))


 def retrieval_ndcg_intersecting(

autointent/metrics/scoring.py

Lines changed: 6 additions & 6 deletions

@@ -71,7 +71,7 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE,
     log_likelihood = labels_array * np.log(scores_array) + (1 - labels_array) * np.log(1 - scores_array)
     clipped_one = log_likelihood.clip(min=-100, max=100)
     res = clipped_one.mean()
-    return res  # type: ignore[no-any-return]
+    return float(res)


 def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -96,7 +96,7 @@ def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> flo
     if labels_.ndim == 1:
         labels_ = (labels_[:, None] == np.arange(n_classes)[None, :]).astype(int)

-    return roc_auc_score(labels_, scores_, average="macro")  # type: ignore[no-any-return]
+    return float(roc_auc_score(labels_, scores_, average="macro"))


 def _calculate_decision_metric(func: DecisionMetricFn, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -206,7 +206,7 @@ def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> fl
     top_ranked_labels = np.argmax(scores_, axis=1)
     is_in = labels_[np.arange(len(labels)), top_ranked_labels]

-    return np.mean(is_in)  # type: ignore[no-any-return]
+    return float(np.mean(is_in))


 def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -242,7 +242,7 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -
     labels_, scores_ = transform(labels, scores)

     n_classes = scores_.shape[1]
-    return 1 - (coverage_error(labels, scores) - 1) / (n_classes - 1)  # type: ignore[no-any-return]
+    return float(1 - (coverage_error(labels, scores) - 1) / (n_classes - 1))


 def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -258,7 +258,7 @@ def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYP
     :param scores: for each utterance, this list contains scores for each of `n_classes` classes
     :return: Score of the scoring metric
     """
-    return -label_ranking_loss(labels, scores)  # type: ignore[no-any-return]
+    return float(-label_ranking_loss(labels, scores))


 def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -274,4 +274,4 @@ def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
     :param scores: for each sample, this list contains scores for each of `n_classes` classes
     :return: mean average precision score
     """
-    return label_ranking_average_precision_score(labels, scores)  # type: ignore[no-any-return]
+    return float(label_ranking_average_precision_score(labels, scores))

autointent/modules/abc/_base.py

Lines changed: 21 additions & 4 deletions

@@ -8,7 +8,6 @@
 from autointent.context import Context
 from autointent.context.optimization_info import Artifact
 from autointent.custom_types import BaseMetadataDict
-from autointent.metrics import METRIC_FN


 class Module(ABC):
@@ -33,14 +32,15 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: METRIC_FN,
-    ) -> float:
+        main_metric: str,
+    ) -> dict[str, float | str]:
         """
         Calculate metric on test set and return metric value.

         :param context: Context to score
         :param split: Split to score on
-        :param metric_fn: Metric function
+        :param main_metric: Name of main metric for evaluation
+        :return: Computed metrics value for the test set or error code of metrics
         """

     @abstractmethod
@@ -102,3 +102,20 @@ def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "Module":
     def get_embedder_name(self) -> str | None:
         """Experimental method."""
         return None
+
+    @staticmethod
+    def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float | str]:
+        """
+        Score metrics on the test set.
+
+        :param params: Params to score
+        :param metrics_dict:
+        :return:
+        """
+        metrics = {}
+        for metric_name, metric_fn in metrics_dict.items():
+            try:
+                metrics[metric_name] = metric_fn(*params)
+            except Exception as e:  # noqa: PERF203, BLE001
+                metrics[metric_name] = str(e)
+        return metrics
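
score_metrics is what makes "report multiple scores" robust: every metric in the dictionary is attempted, and a metric that raises is recorded as its error message instead of aborting the run. The same pattern in isolation, with scikit-learn metrics standing in for autointent's metric functions:

from sklearn.metrics import accuracy_score, roc_auc_score

y_true = [1, 1, 1, 1]  # only one class present
y_pred = [1, 0, 1, 1]

metrics_dict = {"decision_accuracy": accuracy_score, "decision_roc_auc": roc_auc_score}

metrics: dict[str, float | str] = {}
for metric_name, metric_fn in metrics_dict.items():
    try:
        metrics[metric_name] = metric_fn(y_true, y_pred)
    except Exception as e:  # noqa: BLE001
        metrics[metric_name] = str(e)

# decision_accuracy computes to 0.75; decision_roc_auc is stored as an error
# string because ROC AUC is undefined when y_true contains a single class.
print(metrics)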

autointent/modules/abc/_decision.py

Lines changed: 6 additions & 6 deletions

@@ -9,7 +9,7 @@
 from autointent import Context
 from autointent.context.optimization_info import DecisionArtifact
 from autointent.custom_types import LabelType
-from autointent.metrics import DecisionMetricFn
+from autointent.metrics import PREDICTION_METRICS_MULTICLASS
 from autointent.modules.abc import Module
 from autointent.schemas import Tag

@@ -44,19 +44,19 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: DecisionMetricFn,
-    ) -> float:
+        main_metric: str,
+    ) -> dict[str, float | str]:
         """
         Calculate metric on test set and return metric value.

         :param context: Context to score
         :param split: Target split
-        :param metric_fn: Metric function
-        :return: Score
+        :param main_metric: Name of main metric for evaluation
+        :return: Computed metrics value for the test set or error code of metrics
         """
         labels, scores = get_decision_evaluation_data(context, split)
         self._decisions = self.predict(scores)
-        return metric_fn(labels, self._decisions)
+        return self.score_metrics((labels, self._decisions), PREDICTION_METRICS_MULTICLASS)

     def get_assets(self) -> DecisionArtifact:
         """Return useful assets that represent intermediate data into context."""
