Skip to content

Commit befe416

Browse files
authored
Add AUROC and AUPRC metrics for binary classification tasks (#244)
* add auroc metric * add auprc metric * update metric_for_best_model help message
1 parent 6750e05 commit befe416

File tree

4 files changed

+22
-16
lines changed

4 files changed

+22
-16
lines changed

src/cnlpt/_cli/train.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
336336
FinalTaskWeightArg = Annotated[float, training_arg_option("final_task_weight")]
337337
FreezeEncoderArg = Annotated[float, training_arg_option("freeze_encoder")]
338338
BiasFitArg = Annotated[bool, training_arg_option("bias_fit")]
339-
ReportProbsArg = Annotated[bool, training_arg_option("report_probs")]
340339
EvalsPerEpochArg = Annotated[int, training_arg_option("evals_per_epoch")]
341340
RichDisplayArg = Annotated[bool, training_arg_option("rich_display")]
342341
LoggingStrategyArg = Annotated[
@@ -415,7 +414,6 @@ def train(
415414
final_task_weight: FinalTaskWeightArg = 1.0,
416415
freeze_encoder: FreezeEncoderArg = 0.0,
417416
bias_fit: BiasFitArg = False,
418-
report_probs: ReportProbsArg = False,
419417
evals_per_epoch: EvalsPerEpochArg = 0,
420418
rich_display: RichDisplayArg = True,
421419
logging_strategy: LoggingStrategyArg = IntervalStrategy.EPOCH,
@@ -537,7 +535,6 @@ def train(
537535
final_task_weight=final_task_weight,
538536
freeze_encoder=freeze_encoder,
539537
bias_fit=bias_fit,
540-
report_probs=report_probs,
541538
evals_per_epoch=evals_per_epoch,
542539
rich_display=rich_display,
543540
logging_strategy=logging_strategy,

src/cnlpt/train_system/args.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ def __post_init__(self):
4747
"help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)."
4848
},
4949
)
50-
report_probs: bool = field(
51-
default=False,
52-
metadata={
53-
"help": "If selected, probability scores will be added to the output prediction file for test data when used with --do_predict."
54-
},
55-
)
5650
evals_per_epoch: int = field(
5751
default=0,
5852
metadata={
@@ -85,6 +79,6 @@ def __post_init__(self):
8579
metric_for_best_model: Union[str, None] = field(
8680
default="avg_macro_f1",
8781
metadata={
88-
"help": 'The metric to use to compare two different models. Average across tasks with "avg_[acc|macro_f1|micro_f1]". Optimize for a specific task with "taskname.[acc|macro_f1|micro_f1]". Optimize for a particular label with "taskname.labelname.f1". Average multiple metrics with "METRIC_1,METRIC_2".'
82+
"help": 'The metric to use to compare two different models. Average across tasks with "avg_[acc|macro_f1|micro_f1]". Optimize for a specific task with "taskname.[acc|macro_f1|micro_f1]". Optimize for a particular label with "taskname.labelname.f1". For binary classification tasks, optimize for AUROC with "taskname.auroc" or for AUPRC with "taskname.labelname.auprc". Average multiple metrics with "METRIC_1,METRIC_2".'
8983
},
9084
)

src/cnlpt/train_system/cnlp_train_system.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,7 @@ def _extract_task_predictions(self, p: EvalPrediction):
8585
preds = np.argmax(raw_preds, axis=3)
8686
else:
8787
preds = np.argmax(raw_preds, axis=1)
88-
if self.args.report_probs:
89-
probs = np.max(
90-
[simple_softmax(logits) for logits in raw_preds],
91-
axis=1,
92-
)
88+
probs = np.array([simple_softmax(logits) for logits in raw_preds])
9389

9490
labels: Union[npt.NDArray[np.int64], None]
9591
task_label_width = 0

src/cnlpt/train_system/metrics.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
from typing import Union
33

44
import numpy as np
5-
from sklearn.metrics import classification_report
5+
from sklearn.metrics import (
6+
average_precision_score,
7+
classification_report,
8+
roc_auc_score,
9+
)
10+
11+
from cnlpt.data import CLASSIFICATION
612

713
from ..data.preprocess import MASK_VALUE
814
from ..data.task_info import TaskInfo
@@ -48,4 +54,17 @@ def compute_metrics(self) -> dict[str, float]:
4854
**{f"{label}.f1": report[label]["f1-score"] for label in self.task.labels},
4955
}
5056

57+
if (
58+
self.task.type == CLASSIFICATION
59+
and len(self.task.labels) == 2
60+
and self.probs is not None
61+
):
62+
task_metrics["auroc"] = roc_auc_score(labels, self.probs[pred_inds[0], 1])
63+
for label in self.task.labels:
64+
task_metrics[f"{label}.auprc"] = average_precision_score(
65+
labels,
66+
self.probs[pred_inds[0], 1],
67+
pos_label=self.task.get_label_id(label),
68+
)
69+
5170
return {f"{self.task.name}.{key}": val for key, val in task_metrics.items()}

0 commit comments

Comments (0)