11"""BertScorer class for transformer-based classification."""
22
33import tempfile
4- from typing import Any
4+ from collections .abc import Callable
5+ from typing import Any , Literal
56
7+ import evaluate
68import numpy as np
79import numpy .typing as npt
810import torch
2224from autointent ._callbacks import REPORTERS_NAMES
2325from autointent .configs import HFModelConfig
2426from autointent .custom_types import ListOfLabels
25- from autointent .metrics .scoring import scoring_f1
2627from autointent .modules .base import BaseScorer
2728
2829
@@ -33,7 +34,7 @@ class BertScorer(BaseScorer):
3334 _model : Any # transformers AutoModel factory returns Any
3435 _tokenizer : Any # transformers AutoTokenizer factory returns Any
3536
36- def __init__ (
37+ def __init__ ( # noqa: PLR0913
3738 self ,
3839 classification_model_config : HFModelConfig | str | dict [str , Any ] | None = None ,
3940 num_train_epochs : int = 3 ,
@@ -44,6 +45,8 @@ def __init__(
4445 val_fraction : float = 0.2 ,
4546 early_stopping_patience : int = 1 ,
4647 early_stopping_threshold : float = 0.0 ,
48+ early_stopping_metric : Literal ["f1" , "accuracy" , "recall" , "precision" ] = "f1" ,
49+ early_stopping_metric_averaging : Literal ["binary" , "macro" , "micro" ] = "macro" , # doesnt affect `accuracy`
4750 ) -> None :
4851 self .classification_model_config = HFModelConfig .from_search_config (classification_model_config )
4952 self .num_train_epochs = num_train_epochs
@@ -54,9 +57,11 @@ def __init__(
5457 self .val_fraction = val_fraction
5558 self .early_stopping_patience = early_stopping_patience
5659 self .early_stopping_threshold = early_stopping_threshold
60+ self .early_stopping_metric = early_stopping_metric
61+ self .early_stopping_metric_averaging = early_stopping_metric_averaging
5762
5863 @classmethod
59- def from_context (
64+ def from_context ( # noqa: PLR0913
6065 cls ,
6166 context : Context ,
6267 classification_model_config : HFModelConfig | str | dict [str , Any ] | None = None ,
@@ -67,6 +72,8 @@ def from_context(
6772 val_fraction : float = 0.2 ,
6873 early_stopping_patience : int = 1 ,
6974 early_stopping_threshold : float = 0.0 ,
75+ early_stopping_metric : Literal ["f1" , "accuracy" , "recall" , "precision" ] = "f1" ,
76+ early_stopping_metric_averaging : Literal ["binary" , "macro" , "micro" ] = "macro" ,
7077 ) -> "BertScorer" :
7178 if classification_model_config is None :
7279 classification_model_config = context .resolve_transformer ()
@@ -83,6 +90,8 @@ def from_context(
8390 val_fraction = val_fraction ,
8491 early_stopping_patience = early_stopping_patience ,
8592 early_stopping_threshold = early_stopping_threshold ,
93+ early_stopping_metric = early_stopping_metric ,
94+ early_stopping_metric_averaging = early_stopping_metric_averaging ,
8695 )
8796
8897 def get_implicit_initialization_params (self ) -> dict [str , Any ]:
@@ -136,11 +145,6 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
136145
137146 tokenized_dataset = dataset .map (tokenize_function , batched = True , batch_size = self .batch_size )
138147
139- metric_name = "eval_f1"
140-
141- def compute_metrics (predictions : EvalPrediction ) -> dict [str , float ]:
142- return {metric_name : scoring_f1 (predictions .label_ids .tolist (), predictions .predictions .tolist ())} # type: ignore[union-attr]
143-
144148 with tempfile .TemporaryDirectory () as tmp_dir :
145149 training_args = TrainingArguments (
146150 output_dir = tmp_dir ,
@@ -154,7 +158,7 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
154158 logging_steps = 10 ,
155159 report_to = self .report_to if self .report_to is not None else "none" ,
156160 use_cpu = self .classification_model_config .device == "cpu" ,
157- metric_for_best_model = metric_name ,
161+ metric_for_best_model = self . early_stopping_metric ,
158162 load_best_model_at_end = True ,
159163 )
160164
@@ -165,7 +169,7 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
165169 eval_dataset = tokenized_dataset ["validation" ],
166170 processing_class = self ._tokenizer ,
167171 data_collator = DataCollatorWithPadding (tokenizer = self ._tokenizer ),
168- compute_metrics = compute_metrics ,
172+ compute_metrics = self . _get_compute_metrics () ,
169173 callbacks = [
170174 EarlyStoppingCallback (
171175 early_stopping_patience = self .early_stopping_patience ,
@@ -178,6 +182,27 @@ def compute_metrics(predictions: EvalPrediction) -> dict[str, float]:
178182
179183 self ._model .eval ()
180184
185+ def _get_compute_metrics (self ) -> Callable [[EvalPrediction ], dict [str , float ]]:
186+ """Construct callable for computing metrics during transformer training.
187+
188+ The result of this function is supposed to pass to :py:class:`transformers.Trainer`.
189+ """
190+ metric_fn = evaluate .load (self .early_stopping_metric )
191+
192+ compute_kwargs = {}
193+
194+ if self .early_stopping_metric in ["f1" , "recall" , "precision" ]:
195+ compute_kwargs ["average" ] = self .early_stopping_metric_averaging
196+
197+ def compute_metrics (output : EvalPrediction ) -> dict [str , float ]:
198+ return metric_fn .compute (
199+ predictions = output .predictions .argmax (axis = - 1 ).tolist (),
200+ references = output .label_ids ,
201+ ** compute_kwargs ,
202+ )
203+
204+ return compute_metrics
205+
181206 def predict (self , utterances : list [str ]) -> npt .NDArray [Any ]:
182207 if not hasattr (self , "_model" ) or not hasattr (self , "_tokenizer" ):
183208 msg = "Model is not trained. Call fit() first."
0 commit comments