
Commit 0b507ae

report multiple scores (#91)
* report multiple scores
* try to fix
* fix score test
* remove main metric
* remove main metric
* fix
* fix
1 parent 6231290 commit 0b507ae

File tree

15 files changed, +196 -58 lines changed


autointent/_callbacks/base.py

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         :param kwargs: Data to log.
         """
 
+    @abstractmethod
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+
     @abstractmethod
     def end_module(self) -> None:
         """End a module."""

autointent/_callbacks/callback_handler.py

Lines changed: 8 additions & 0 deletions
@@ -44,6 +44,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.call_events("log_value", **kwargs)
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.call_events("log_metrics", metrics=metrics)
+
     def end_module(self) -> None:
         """End a module."""
         self.call_events("end_module")

autointent/_callbacks/tensorboard.py

Lines changed: 16 additions & 0 deletions
@@ -73,6 +73,22 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
             else:
                 self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        if self.module_writer is None:
+            msg = "start_run must be called before log_value."
+            raise RuntimeError(msg)
+
+        for key, value in metrics.items():
+            if isinstance(value, int | float):
+                self.module_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
+            else:
+                self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.
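
The TensorBoard handler reuses the scalar-vs-text dispatch already used by log_value: numeric values become add_scalar calls, everything else (including error strings) becomes add_text. A small self-contained sketch of that dispatch, with plain dicts standing in for SummaryWriter; note that isinstance(value, int | float) requires Python 3.10+:

from typing import Any


def split_metrics(metrics: dict[str, Any]) -> tuple[dict[str, float], dict[str, str]]:
    """Separate numeric metric values (scalars) from everything else (logged as text)."""
    scalars: dict[str, float] = {}
    texts: dict[str, str] = {}
    for key, value in metrics.items():
        if isinstance(value, int | float):  # PEP 604 union in isinstance, Python >= 3.10
            scalars[key] = float(value)
        else:
            texts[key] = str(value)  # e.g. an error message produced by Module.score_metrics
    return scalars, texts


print(split_metrics({"scoring_accuracy": 0.93, "scoring_roc_auc": "metric failed"}))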

autointent/_callbacks/wandb.py

Lines changed: 8 additions & 0 deletions
@@ -59,6 +59,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.wandb.log(kwargs)
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.wandb.log(metrics)
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.

autointent/metrics/decision.py

Lines changed: 6 additions & 6 deletions
@@ -50,7 +50,7 @@ def decision_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> f
     :return: Score of the decision accuracy
     """
     y_true_, y_pred_ = transform(y_true, y_pred)
-    return np.mean(y_true_ == y_pred_)  # type: ignore[no-any-return]
+    return float(np.mean(y_true_ == y_pred_))
 
 
 def _decision_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -83,7 +83,7 @@ def _decision_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE
         binarized_pred = (y_pred_ == k).astype(int)
         roc_auc_scores.append(roc_auc_score(binarized_true, binarized_pred))
 
-    return np.mean(roc_auc_scores)  # type: ignore[return-value]
+    return float(np.mean(roc_auc_scores))
 
 
 def _decision_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -98,7 +98,7 @@ def _decision_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE
     :param y_pred: Predicted values of labels
     :return: Score of the decision accuracy
     """
-    return roc_auc_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(roc_auc_score(y_true, y_pred, average="macro"))
 
 
 def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -135,7 +135,7 @@ def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
     :param y_pred: Predicted values of labels
     :return: Score of the decision precision
     """
-    return precision_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(precision_score(y_true, y_pred, average="macro"))
 
 
 def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -150,7 +150,7 @@ def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> flo
     :param y_pred: Predicted values of labels
     :return: Score of the decision recall
     """
-    return recall_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(recall_score(y_true, y_pred, average="macro"))
 
 
 def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
@@ -165,4 +165,4 @@ def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
     :param y_pred: Predicted values of labels
     :return: Score of the decision accuracy
     """
-    return f1_score(y_true, y_pred, average="macro")  # type: ignore[no-any-return]
+    return float(f1_score(y_true, y_pred, average="macro"))
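
The change in this file is mechanical but worth spelling out: NumPy reductions and scikit-learn scorers return NumPy scalar types such as numpy.float64, while the declared -> float annotation (and JSON-friendly reporting) expects a built-in float, hence the float(...) wrapping in place of the previous type: ignore comments. A standalone illustration, not part of the commit:

import numpy as np

raw = np.mean([1.0, 0.0, 1.0])             # numpy.float64
wrapped = float(np.mean([1.0, 0.0, 1.0]))  # built-in float

print(type(raw).__name__, type(wrapped).__name__)  # float64 float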

autointent/metrics/retrieval.py

Lines changed: 1 addition & 1 deletion
@@ -630,7 +630,7 @@ def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE
         cur_idcg = _idcg(rel_scores, k)
         ndcg_scores.append(0.0 if cur_idcg == 0 else cur_dcg / cur_idcg)
 
-    return np.mean(ndcg_scores)  # type: ignore[return-value]
+    return float(np.mean(ndcg_scores))
 
 
 def retrieval_ndcg_intersecting(

autointent/metrics/scoring.py

Lines changed: 7 additions & 6 deletions
@@ -71,7 +71,8 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE,
     log_likelihood = labels_array * np.log(scores_array) + (1 - labels_array) * np.log(1 - scores_array)
     clipped_one = log_likelihood.clip(min=-100, max=100)
     res = clipped_one.mean()
-    return res  # type: ignore[no-any-return]
+    # test produces different output
+    return round(float(res), 6)
 
 
 def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -96,7 +97,7 @@ def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> flo
     if labels_.ndim == 1:
         labels_ = (labels_[:, None] == np.arange(n_classes)[None, :]).astype(int)
 
-    return roc_auc_score(labels_, scores_, average="macro")  # type: ignore[no-any-return]
+    return float(roc_auc_score(labels_, scores_, average="macro"))
 
 
 def _calculate_decision_metric(func: DecisionMetricFn, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -206,7 +207,7 @@ def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> fl
     top_ranked_labels = np.argmax(scores_, axis=1)
     is_in = labels_[np.arange(len(labels)), top_ranked_labels]
 
-    return np.mean(is_in)  # type: ignore[no-any-return]
+    return float(np.mean(is_in))
 
 
 def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -242,7 +243,7 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -
     labels_, scores_ = transform(labels, scores)
 
     n_classes = scores_.shape[1]
-    return 1 - (coverage_error(labels, scores) - 1) / (n_classes - 1)  # type: ignore[no-any-return]
+    return float(1 - (coverage_error(labels, scores) - 1) / (n_classes - 1))
 
 
 def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -258,7 +259,7 @@ def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYP
     :param scores: for each utterance, this list contains scores for each of `n_classes` classes
     :return: Score of the scoring metric
     """
-    return -label_ranking_loss(labels, scores)  # type: ignore[no-any-return]
+    return float(-label_ranking_loss(labels, scores))
 
 
 def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
@@ -274,4 +275,4 @@ def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
     :param scores: for each sample, this list contains scores for each of `n_classes` classes
     :return: mean average precision score
     """
-    return label_ranking_average_precision_score(labels, scores)  # type: ignore[no-any-return]
+    return float(label_ranking_average_precision_score(labels, scores))
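
The only non-mechanical change here is in scoring_log_likelihood: the result is now rounded to six decimals, and the accompanying "# test produces different output" comment suggests this is meant to absorb tiny cross-platform floating-point differences in the test suite. As a hedged aside, tolerance-based comparison is the usual alternative; a minimal illustration of both:

import math

a = 0.1 + 0.2  # 0.30000000000000004 due to binary floating point
b = 0.3

print(round(a, 6) == round(b, 6))        # True: rounding hides the tiny difference
print(math.isclose(a, b, rel_tol=1e-9))  # True: tolerance-based comparison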

autointent/modules/abc/_base.py

Lines changed: 19 additions & 4 deletions
@@ -10,7 +10,6 @@
 from autointent.context import Context
 from autointent.context.optimization_info import Artifact
 from autointent.custom_types import BaseMetadataDict
-from autointent.metrics import METRIC_FN
 
 
 class Module(ABC):
@@ -35,14 +34,13 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: METRIC_FN,
-    ) -> float:
+    ) -> dict[str, float | str]:
         """
         Calculate metric on test set and return metric value.
 
         :param context: Context to score
         :param split: Split to score on
-        :param metric_fn: Metric function
+        :return: Computed metrics value for the test set or error code of metrics
         """
 
     @abstractmethod
@@ -104,3 +102,20 @@ def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "Module":
     def get_embedder_name(self) -> str | None:
         """Experimental method."""
         return None
+
+    @staticmethod
+    def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float | str]:
+        """
+        Score metrics on the test set.
+
+        :param params: Params to score
+        :param metrics_dict:
+        :return:
+        """
+        metrics = {}
+        for metric_name, metric_fn in metrics_dict.items():
+            try:
+                metrics[metric_name] = metric_fn(*params)
+            except Exception as e:  # noqa: PERF203, BLE001
+                metrics[metric_name] = str(e)
+        return metrics
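
To make the new contract concrete, here is a small standalone sketch of what score_metrics does: every metric in the supplied registry is evaluated on the same argument tuple, and a metric that raises contributes its error message instead of a number, which is why the return type is dict[str, float | str]. The toy metrics and registry below are invented for illustration; the real registries live in autointent.metrics.

from typing import Any, Callable


def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Callable[..., float]]) -> dict[str, float | str]:
    # Same logic as the staticmethod added to Module: collect every metric,
    # falling back to the exception text when a metric cannot be computed.
    metrics: dict[str, float | str] = {}
    for metric_name, metric_fn in metrics_dict.items():
        try:
            metrics[metric_name] = metric_fn(*params)
        except Exception as e:  # noqa: BLE001
            metrics[metric_name] = str(e)
    return metrics


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def always_fails(y_true: list[int], y_pred: list[int]) -> float:
    raise ValueError("metric not defined for this label layout")


registry = {"accuracy": accuracy, "broken_metric": always_fails}  # hypothetical registry
print(score_metrics(([1, 0, 1], [1, 1, 1]), registry))
# {'accuracy': 0.6666666666666666, 'broken_metric': 'metric not defined for this label layout'}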

autointent/modules/abc/_decision.py

Lines changed: 4 additions & 6 deletions
@@ -9,7 +9,7 @@
 from autointent import Context
 from autointent.context.optimization_info import DecisionArtifact
 from autointent.custom_types import LabelType
-from autointent.metrics import DecisionMetricFn
+from autointent.metrics import PREDICTION_METRICS_MULTICLASS
 from autointent.modules.abc import Module
 from autointent.schemas import Tag
 
@@ -44,19 +44,17 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: DecisionMetricFn,
-    ) -> float:
+    ) -> dict[str, float | str]:
         """
         Calculate metric on test set and return metric value.
 
         :param context: Context to score
         :param split: Target split
-        :param metric_fn: Metric function
-        :return: Score
+        :return: Computed metrics value for the test set or error code of metrics
         """
         labels, scores = get_decision_evaluation_data(context, split)
         self._decisions = self.predict(scores)
-        return metric_fn(labels, self._decisions)
+        return self.score_metrics((labels, self._decisions), PREDICTION_METRICS_MULTICLASS)
 
     def get_assets(self) -> DecisionArtifact:
         """Return useful assets that represent intermediate data into context."""

autointent/modules/abc/_scoring.py

Lines changed: 5 additions & 6 deletions
@@ -8,7 +8,7 @@
 from autointent import Context
 from autointent.context.optimization_info import ScorerArtifact
 from autointent.custom_types import Split
-from autointent.metrics import ScoringMetricFn
+from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
 from autointent.modules.abc import Module
 
 
@@ -24,15 +24,13 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: ScoringMetricFn,
-    ) -> float:
+    ) -> dict[str, float | str]:
         """
         Evaluate the scorer on a test set and compute the specified metric.
 
         :param context: Context containing test set and other data.
         :param split: Target split
-        :param metric_fn: Function to compute the scoring metric.
-        :return: Computed metric value for the test set.
+        :return: Computed metrics value for the test set or error code of metrics
         """
         if split == "validation":
             utterances = context.data_handler.validation_utterances(0)
@@ -58,7 +56,8 @@ def score(
         self._validation_scores = self.predict(context.data_handler.validation_utterances(1))
         self._test_scores = self.predict(context.data_handler.test_utterances())
 
-        return metric_fn(labels, scores)
+        metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
+        return self.score_metrics((labels, scores), metrics_dict)
 
     def get_assets(self) -> ScorerArtifact:
         """
