Skip to content

Commit fcaceb5

Browse files
committed
update macro-f1 to return classwise scores
1 parent 1d8a7c3 commit fcaceb5

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import torch
24
import torchmetrics
35

@@ -32,7 +34,11 @@ class MacroF1(torchmetrics.Metric):
3234
"""
3335

3436
def __init__(
35-
self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
37+
self,
38+
num_labels: int,
39+
dist_sync_on_step: bool = False,
40+
threshold: float = 0.5,
41+
average: Literal["mean", "none"] | None = "mean",
3642
):
3743
super().__init__(dist_sync_on_step=dist_sync_on_step)
3844

@@ -52,6 +58,7 @@ def __init__(
5258
dist_reduce_fx="sum",
5359
)
5460
self.threshold = threshold
61+
self.average = average
5562

5663
def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
5764
"""
@@ -85,7 +92,14 @@ def compute(self) -> torch.Tensor:
8592
classwise_f1 = 2 * precision * recall / (precision + recall)
8693
# if (precision and recall are 0) or (precision is nan), set f1 to 0
8794
classwise_f1 = classwise_f1.nan_to_num()
88-
return torch.mean(classwise_f1)
95+
96+
if self.average == "mean":
97+
return torch.mean(classwise_f1)
98+
99+
if self.average is None or self.average == "none":
100+
return classwise_f1
101+
102+
raise ValueError(f"{self.average} not supported")
89103

90104

91105
class BalancedAccuracy(torchmetrics.Metric):

0 commit comments

Comments
 (0)