1- from typing import Literal
2-
31import torch
42import torchmetrics
53
@@ -34,11 +32,7 @@ class MacroF1(torchmetrics.Metric):
3432 """
3533
3634 def __init__ (
37- self ,
38- num_labels : int ,
39- dist_sync_on_step : bool = False ,
40- threshold : float = 0.5 ,
41- average : Literal ["mean" , "none" ] | None = "mean" ,
35+ self , num_labels : int , dist_sync_on_step : bool = False , threshold : float = 0.5
4236 ):
4337 super ().__init__ (dist_sync_on_step = dist_sync_on_step )
4438
@@ -58,7 +52,6 @@ def __init__(
5852 dist_reduce_fx = "sum" ,
5953 )
6054 self .threshold = threshold
61- self .average = average
6255
6356 def update (self , preds : torch .Tensor , labels : torch .Tensor ) -> None :
6457 """
@@ -92,14 +85,7 @@ def compute(self) -> torch.Tensor:
9285 classwise_f1 = 2 * precision * recall / (precision + recall )
9386 # if (precision and recall are 0) or (precision is nan), set f1 to 0
9487 classwise_f1 = classwise_f1 .nan_to_num ()
95-
96- if self .average == "mean" :
97- return torch .mean (classwise_f1 )
98-
99- if self .average is None or self .average == "none" :
100- return classwise_f1
101-
102- raise ValueError (f"{ self .average } not supported" )
88+ return torch .mean (classwise_f1 )
10389
10490
10591class BalancedAccuracy (torchmetrics .Metric ):
0 commit comments