@@ -23,7 +23,6 @@ def __init__(self, num_classes, macro_averaging=False):
2323 self .y_true = []
2424 self .y_pred = []
2525
26-
2726 def forward (self , target , preds ):
2827 """
2928 Stores predictions and targets for computing the F1 score.
@@ -57,7 +56,11 @@ def compute_f1(self):
5756 y_true = torch .cat (self .y_true )
5857 y_pred = torch .cat (self .y_pred )
5958
60- return self ._macro_F1 (y_true , y_pred ) if self .macro_averaging else self ._micro_F1 (y_true , y_pred )
59+ return (
60+ self ._macro_F1 (y_true , y_pred )
61+ if self .macro_averaging
62+ else self ._micro_F1 (y_true , y_pred )
63+ )
6164
6265 def _micro_F1 (self , target , preds ):
6366 """Computes Micro F1 Score (global TP, FP, FN)."""
@@ -111,9 +114,13 @@ def __returnmetric__(self):
111114 y_true = torch .cat ([t .unsqueeze (0 ) if t .dim () == 0 else t for t in self .y_true ])
112115 y_pred = torch .cat ([t .unsqueeze (0 ) if t .dim () == 0 else t for t in self .y_pred ])
113116
114- return self ._macro_F1 (y_true , y_pred ) if self .macro_averaging else self ._micro_F1 (y_true , y_pred )
117+ return (
118+ self ._macro_F1 (y_true , y_pred )
119+ if self .macro_averaging
120+ else self ._micro_F1 (y_true , y_pred )
121+ )
115122
116123 def __reset__ (self ):
117124 """Resets stored predictions and targets."""
118125 self .y_true = []
119- self .y_pred = []
126+ self .y_pred = []
0 commit comments