Skip to content

Commit ec55bfb

Browse files
committed
Revert "update macro-f1 to return classwise scores"
This reverts commit fcaceb5.
1 parent 8688c3b commit ec55bfb

File tree

1 file changed

+2
-16
lines changed

1 file changed

+2
-16
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Literal
2-
31
import torch
42
import torchmetrics
53

@@ -34,11 +32,7 @@ class MacroF1(torchmetrics.Metric):
3432
"""
3533

3634
def __init__(
37-
self,
38-
num_labels: int,
39-
dist_sync_on_step: bool = False,
40-
threshold: float = 0.5,
41-
average: Literal["mean", "none"] | None = "mean",
35+
self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
4236
):
4337
super().__init__(dist_sync_on_step=dist_sync_on_step)
4438

@@ -58,7 +52,6 @@ def __init__(
5852
dist_reduce_fx="sum",
5953
)
6054
self.threshold = threshold
61-
self.average = average
6255

6356
def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
6457
"""
@@ -92,14 +85,7 @@ def compute(self) -> torch.Tensor:
9285
classwise_f1 = 2 * precision * recall / (precision + recall)
9386
# if (precision and recall are 0) or (precision is nan), set f1 to 0
9487
classwise_f1 = classwise_f1.nan_to_num()
95-
96-
if self.average == "mean":
97-
return torch.mean(classwise_f1)
98-
99-
if self.average is None or self.average == "none":
100-
return classwise_f1
101-
102-
raise ValueError(f"{self.average} not supported")
88+
return torch.mean(classwise_f1)
10389

10490

10591
class BalancedAccuracy(torchmetrics.Metric):

0 commit comments

Comments
 (0)