1+ from typing import Literal
2+
13import torch
24import torchmetrics
35
@@ -32,7 +34,11 @@ class MacroF1(torchmetrics.Metric):
3234 """
3335
3436 def __init__ (
35- self , num_labels : int , dist_sync_on_step : bool = False , threshold : float = 0.5
37+ self ,
38+ num_labels : int ,
39+ dist_sync_on_step : bool = False ,
40+ threshold : float = 0.5 ,
41+ average : Literal ["mean" , "none" ] | None = "mean" ,
3642 ):
3743 super ().__init__ (dist_sync_on_step = dist_sync_on_step )
3844
@@ -52,6 +58,7 @@ def __init__(
5258 dist_reduce_fx = "sum" ,
5359 )
5460 self .threshold = threshold
61+ self .average = average
5562
5663 def update (self , preds : torch .Tensor , labels : torch .Tensor ) -> None :
5764 """
@@ -85,7 +92,14 @@ def compute(self) -> torch.Tensor:
8592 classwise_f1 = 2 * precision * recall / (precision + recall )
8693 # if (precision and recall are 0) or (precision is nan), set f1 to 0
8794 classwise_f1 = classwise_f1 .nan_to_num ()
88- return torch .mean (classwise_f1 )
95+
96+ if self .average == "mean" :
97+ return torch .mean (classwise_f1 )
98+
99+ if self .average is None or self .average == "none" :
100+ return classwise_f1
101+
102+ raise ValueError (f"{ self .average } not supported" )
89103
90104
91105class BalancedAccuracy (torchmetrics .Metric ):
0 commit comments