Commit e72bde5

use our metrics
Parent: fc567ee

2 files changed (+9, -14 lines)

autointent/configs/_transformers.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,8 @@
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
 from typing_extensions import Self, assert_never

+from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
+

 class TokenizerConfig(BaseModel):
     padding: bool | Literal["longest", "max_length", "do_not_pad"] = True

@@ -128,5 +130,4 @@ class EarlyStoppingConfig(BaseModel):
     val_fraction: float = 0.2
     patience: int = 1
     threshold: float = 0.0
-    metric: Literal["f1", "accuracy", "recall", "precision"] | None = "f1"
-    averaging: Literal["macro", "micro"] = "macro"  # doesnt affect `accuracy`
+    metric: Literal[tuple((SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS).keys())] | None = "scoring_f1"  # type: ignore[valid-type]
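The metric field of EarlyStoppingConfig now takes its allowed values from the keys of the library's own scoring-metric registries instead of a hard-coded list, and the separate averaging field is dropped, presumably because averaging is handled inside the registered metrics themselves. Below is a minimal sketch of the dynamic-Literal pattern this relies on, written against a hypothetical registry: the real SCORING_METRICS_MULTICLASS / SCORING_METRICS_MULTILABEL dicts live in autointent.metrics, and only the "scoring_f1" key is confirmed by the new default value.

# Minimal sketch of the dynamic-Literal pattern, against a hypothetical registry.
# The real registries live in autointent.metrics; only "scoring_f1" is confirmed
# here (it is the new default).
from typing import Literal

from pydantic import BaseModel

HYPOTHETICAL_SCORING_METRICS: dict[str, object] = {
    "scoring_f1": object(),        # stand-ins for the real metric callables
    "scoring_accuracy": object(),
}


class EarlyStoppingSketch(BaseModel):
    # The Literal is built at import time from the registry keys, so the set of
    # accepted metric names stays in sync with whatever the library registers.
    metric: Literal[tuple(HYPOTHETICAL_SCORING_METRICS.keys())] | None = "scoring_f1"  # type: ignore[valid-type]


EarlyStoppingSketch(metric="scoring_accuracy")   # validates
# EarlyStoppingSketch(metric="f1")               # would now fail validation

Because the annotation is computed at runtime, static type checkers cannot see the literal values, hence the type: ignore[valid-type] in both the sketch and the real code.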

autointent/modules/scoring/_bert.py

Lines changed: 6 additions & 12 deletions
@@ -4,7 +4,6 @@
 from collections.abc import Callable
 from typing import Any

-import evaluate
 import numpy as np
 import numpy.typing as npt
 import torch

@@ -25,6 +24,7 @@
 from autointent._callbacks import REPORTERS_NAMES
 from autointent.configs import EarlyStoppingConfig, HFModelConfig
 from autointent.custom_types import ListOfLabels
+from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
 from autointent.modules.base import BaseScorer

@@ -191,19 +191,13 @@ def _get_compute_metrics(self) -> Callable[[EvalPrediction], dict[str, float]] | None:
         if self.early_stopping_config.metric is None:
             return None

-        metric_fn = evaluate.load(self.early_stopping_config.metric)
-
-        compute_kwargs = {}
-
-        if self.early_stopping_config.metric in ["f1", "recall", "precision"]:
-            compute_kwargs["average"] = self.early_stopping_config.averaging
+        metric_name = self.early_stopping_config.metric
+        metric_fn = (SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS)[metric_name]

         def compute_metrics(output: EvalPrediction) -> dict[str, float]:
-            return metric_fn.compute(  # type: ignore[no-any-return]
-                predictions=output.predictions.argmax(axis=-1).tolist(),  # type: ignore[union-attr]
-                references=output.label_ids,
-                **compute_kwargs,
-            )
+            return {
+                metric_name: metric_fn(output.label_ids.tolist(), output.predictions.tolist())  # type: ignore[union-attr]
+            }

         return compute_metrics

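In _bert.py the evaluate dependency is removed from _get_compute_metrics: instead of evaluate.load(...) plus an argmax and an optional average kwarg, the metric callable is looked up in the merged internal registries and handed the raw label ids and prediction scores directly. A likely motivation is that early stopping then scores with the same metric implementations as the rest of the library. The sketch below shows that callback shape under the assumption that each registered metric is a plain callable of the form (labels, predictions) -> float; the toy accuracy metric and the registry are hypothetical, and only the lookup-and-closure structure mirrors the diff.

# Sketch of the reworked callback, assuming registry values are plain callables
# taking (labels, predictions) and returning a float. The toy metric below is an
# assumption; the real ones come from autointent.metrics.
from collections.abc import Callable

import numpy as np
from transformers import EvalPrediction

HYPOTHETICAL_METRIC_REGISTRY: dict[str, Callable[..., float]] = {
    "scoring_accuracy": lambda labels, scores: float(
        (np.asarray(scores).argmax(axis=-1) == np.asarray(labels)).mean()
    ),
}


def get_compute_metrics(metric_name: str) -> Callable[[EvalPrediction], dict[str, float]]:
    metric_fn = HYPOTHETICAL_METRIC_REGISTRY[metric_name]

    def compute_metrics(output: EvalPrediction) -> dict[str, float]:
        # Label ids and prediction scores are passed through unchanged; each
        # metric decides how to interpret the scores (argmax, thresholding, ...).
        return {metric_name: metric_fn(output.label_ids.tolist(), output.predictions.tolist())}

    return compute_metrics


# Usage with a Hugging Face Trainer:
#   Trainer(..., compute_metrics=get_compute_metrics("scoring_accuracy"))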