@@ -76,7 +76,9 @@ def _micro_F1(self):
7676 precision = tp / (tp + fp + 1e-8 ) # Avoid division by zero
7777 recall = tp / (tp + fn + 1e-8 ) # Avoid division by zero
7878
79- f1 = 2 * precision * recall / (precision + recall + 1e-8 ) # Avoid division by zero
79+ f1 = (
80+ 2 * precision * recall / (precision + recall + 1e-8 )
81+ ) # Avoid division by zero
8082 return f1
8183
8284 def _macro_F1 (self ):
@@ -91,10 +93,18 @@ def _macro_F1(self):
9193 torch.Tensor
9294 The macro-averaged F1 score.
9395 """
94- precision_per_class = self .tp / (self .tp + self .fp + 1e-8 ) # Avoid division by zero
95- recall_per_class = self .tp / (self .tp + self .fn + 1e-8 ) # Avoid division by zero
96- f1_per_class = 2 * precision_per_class * recall_per_class / (
97- precision_per_class + recall_per_class + 1e-8 ) # Avoid division by zero
96+ precision_per_class = self .tp / (
97+ self .tp + self .fp + 1e-8
98+ ) # Avoid division by zero
99+ recall_per_class = self .tp / (
100+ self .tp + self .fn + 1e-8
101+ ) # Avoid division by zero
102+ f1_per_class = (
103+ 2
104+ * precision_per_class
105+ * recall_per_class
106+ / (precision_per_class + recall_per_class + 1e-8 )
107+ ) # Avoid division by zero
98108
99109 # Take the average of F1 scores across all classes
100110 f1_score = torch .mean (f1_per_class )
@@ -138,4 +148,3 @@ def forward(self, preds, target):
138148 f1_score = self ._micro_F1 ()
139149
140150 return f1_score
141-
0 commit comments