Commit 4ef46df

try to fix "not found f1" error
1 parent f9ca4c4 commit 4ef46df

2 files changed: +37 -11 lines changed

autointent/modules/scoring/_bert.py

Lines changed: 36 additions & 11 deletions
@@ -1,8 +1,10 @@
 """BertScorer class for transformer-based classification."""
 
 import tempfile
-from typing import Any
+from collections.abc import Callable
+from typing import Any, Literal
 
+import evaluate
 import numpy as np
 import numpy.typing as npt
 import torch
@@ -22,7 +24,6 @@
 from autointent._callbacks import REPORTERS_NAMES
 from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
-from autointent.metrics.scoring import scoring_f1
 from autointent.modules.base import BaseScorer
 
 
@@ -33,7 +34,7 @@ class BertScorer(BaseScorer):
     _model: Any  # transformers AutoModel factory returns Any
     _tokenizer: Any  # transformers AutoTokenizer factory returns Any
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
@@ -44,6 +45,8 @@ def __init__(
         val_fraction: float = 0.2,
         early_stopping_patience: int = 1,
         early_stopping_threshold: float = 0.0,
+        early_stopping_metric: Literal["f1", "accuracy", "recall", "precision"] = "f1",
+        early_stopping_metric_averaging: Literal["binary", "macro", "micro"] = "macro",  # doesn't affect `accuracy`
     ) -> None:
         self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
         self.num_train_epochs = num_train_epochs
@@ -54,9 +57,11 @@ def __init__(
         self.val_fraction = val_fraction
         self.early_stopping_patience = early_stopping_patience
         self.early_stopping_threshold = early_stopping_threshold
+        self.early_stopping_metric = early_stopping_metric
+        self.early_stopping_metric_averaging = early_stopping_metric_averaging
 
     @classmethod
-    def from_context(
+    def from_context(  # noqa: PLR0913
         cls,
         context: Context,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
@@ -67,6 +72,8 @@ def from_context(
         val_fraction: float = 0.2,
         early_stopping_patience: int = 1,
         early_stopping_threshold: float = 0.0,
+        early_stopping_metric: Literal["f1", "accuracy", "recall", "precision"] = "f1",
+        early_stopping_metric_averaging: Literal["binary", "macro", "micro"] = "macro",
     ) -> "BertScorer":
         if classification_model_config is None:
             classification_model_config = context.resolve_transformer()
@@ -83,6 +90,8 @@ def from_context(
             val_fraction=val_fraction,
             early_stopping_patience=early_stopping_patience,
             early_stopping_threshold=early_stopping_threshold,
+            early_stopping_metric=early_stopping_metric,
+            early_stopping_metric_averaging=early_stopping_metric_averaging,
         )
 
     def get_implicit_initialization_params(self) -> dict[str, Any]:
@@ -136,11 +145,6 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
 
         tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
 
-        metric_name = "eval_f1"
-
-        def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
-            return {metric_name: scoring_f1(predictions.label_ids.tolist(), predictions.predictions.tolist())}  # type: ignore[union-attr]
-
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
                 output_dir=tmp_dir,
@@ -154,7 +158,7 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
                 logging_steps=10,
                 report_to=self.report_to if self.report_to is not None else "none",
                 use_cpu=self.classification_model_config.device == "cpu",
-                metric_for_best_model=metric_name,
+                metric_for_best_model=self.early_stopping_metric,
                 load_best_model_at_end=True,
             )
 
@@ -165,7 +169,7 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
                 eval_dataset=tokenized_dataset["validation"],
                 processing_class=self._tokenizer,
                 data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
-                compute_metrics=compute_metrics,
+                compute_metrics=self._get_compute_metrics(),
                 callbacks=[
                     EarlyStoppingCallback(
                         early_stopping_patience=self.early_stopping_patience,
@@ -178,6 +182,27 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
 
         self._model.eval()
 
+    def _get_compute_metrics(self) -> Callable[[EvalPrediction], dict[str, float]]:
+        """Construct a callable for computing metrics during transformer training.
+
+        The result of this function is meant to be passed to :py:class:`transformers.Trainer`.
+        """
+        metric_fn = evaluate.load(self.early_stopping_metric)
+
+        compute_kwargs = {}
+
+        if self.early_stopping_metric in ["f1", "recall", "precision"]:
+            compute_kwargs["average"] = self.early_stopping_metric_averaging
+
+        def compute_metrics(output: EvalPrediction) -> dict[str, float]:
+            return metric_fn.compute(
+                predictions=output.predictions.argmax(axis=-1).tolist(),
+                references=output.label_ids,
+                **compute_kwargs,
+            )
+
+        return compute_metrics
+
     def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
             msg = "Model is not trained. Call fit() first."

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ dependencies = [
     "transformers[torch] (>=4.49.0,<5.0.0)",
     "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
     "codecarbon (==2.6)",
+    "evaluate (>=0.4.3,<0.5.0)",
 ]
 
 [project.optional-dependencies]
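
As a quick, illustrative sanity check (not from the repository) that the metric names accepted by early_stopping_metric all resolve through the newly pinned evaluate package:

import evaluate

# "accuracy" takes no averaging argument; the others do (mirrors the check in _get_compute_metrics)
for name in ("f1", "accuracy", "recall", "precision"):
    kwargs = {"average": "macro"} if name != "accuracy" else {}
    print(name, evaluate.load(name).compute(predictions=[0, 1, 1], references=[0, 1, 0], **kwargs))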
