1+ import numpy as np
12import torch
23import torch .nn as nn
34
@@ -52,13 +53,14 @@ def __init__(self, num_classes, macro_averaging=False):
5253
5354 self .num_classes = num_classes
5455 self .macro_averaging = macro_averaging
55-
56+ self .y_true = []
57+ self .y_pred = []
5658 # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
5759 self .tp = torch .zeros (num_classes )
5860 self .fp = torch .zeros (num_classes )
5961 self .fn = torch .zeros (num_classes )
6062
61- def _micro_F1 (self ):
63+ def _micro_F1 (self , target , preds ):
6264 """
6365 Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
6466
@@ -69,6 +71,11 @@ def _micro_F1(self):
6971 torch.Tensor
7072 The micro-averaged F1 score.
7173 """
74+ for i in range (self .num_classes ):
75+ self .tp [i ] += torch .sum ((preds == i ) & (target == i )).float ()
76+ self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
77+ self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
78+
7279 tp = torch .sum (self .tp )
7380 fp = torch .sum (self .fp )
7481 fn = torch .sum (self .fn )
@@ -81,7 +88,7 @@ def _micro_F1(self):
8188 ) # Avoid division by zero
8289 return f1
8390
84- def _macro_F1 (self ):
91+ def _macro_F1 (self , target , preds ):
8592 """
8693 Compute the Macro F1 score by calculating the F1 score per class and averaging.
8794
@@ -93,6 +100,12 @@ def _macro_F1(self):
93100 torch.Tensor
94101 The macro-averaged F1 score.
95102 """
103+ # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
104+ for i in range (self .num_classes ):
105+ self .tp [i ] += torch .sum ((preds == i ) & (target == i )).float ()
106+ self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
107+ self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
108+
96109 precision_per_class = self .tp / (
97110 self .tp + self .fp + 1e-8
98111 ) # Avoid division by zero
@@ -133,18 +146,24 @@ def forward(self, preds, target):
133146 The computed F1 score (either micro or macro, based on `macro_averaging`).
134147 """
135148 preds = torch .argmax (preds , dim = - 1 )
149+ self .y_true .append (target )
150+ self .y_pred .append (preds )
151+
152+ def __returnmetric__ (self ):
153+ if self .y_true == [] or self .y_pred == []:
154+ return np .nan
155+ if isinstance (self .y_true , list ):
156+ if len (self .y_true ) == 1 :
157+ self .y_true = self .y_true [0 ]
158+ self .y_pred = self .y_pred [0 ]
159+ else :
160+ self .y_true = torch .cat (self .y_true )
161+ self .y_pred = torch .cat (self .y_pred )
162+ return self ._micro_F1 (self .y_true , self .y_pred ) if not self .macro_averaging else self ._macro_F1 (self .y_true , self .y_pred )
163+
164+ def __reset__ (self ):
165+ self .y_true = []
166+ self .y_pred = []
167+ return None
136168
137- # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
138- for i in range (self .num_classes ):
139- self .tp [i ] += torch .sum ((preds == i ) & (target == i )).float ()
140- self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
141- self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
142169
143- if self .macro_averaging :
144- # Calculate Macro F1 score
145- f1_score = self ._macro_F1 ()
146- else :
147- # Calculate Micro F1 score
148- f1_score = self ._micro_F1 ()
149-
150- return f1_score
0 commit comments