44
55class F1Score (nn .Module ):
66 """
7- F1 Score implementation with direct averaging inside the compute method.
7+ F1 Score implementation with support for both macro and micro averaging.
8+
9+ This class computes the F1 score during training using either macro or micro averaging.
10+ The F1 score is calculated based on the true positives (TP), false positives (FP),
11+ and false negatives (FN) for each class.
812
913 Parameters
1014 ----------
1115 num_classes : int
12- Number of classes.
16+ The number of classes in the classification task.
17+
18+ macro_averaging : bool, optional, default=False
19+ If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
1320
1421 Attributes
1522 ----------
1623 num_classes : int
17- The number of classes.
24+ The number of classes in the classification task .
1825
1926 tp : torch.Tensor
20- Tensor for True Positives (TP) for each class.
27+ Tensor storing the count of True Positives (TP) for each class.
2128
2229 fp : torch.Tensor
23- Tensor for False Positives (FP) for each class.
30+ Tensor storing the count of False Positives (FP) for each class.
2431
2532 fn : torch.Tensor
26- Tensor for False Negatives (FN) for each class.
33+ Tensor storing the count of False Negatives (FN) for each class.
34+
35+ macro_averaging : bool
36+ A flag indicating whether to compute the macro-averaged F1 score or not.
2737 """
2838
29- def __init__ (self , num_classes ):
39+ def __init__ (self , num_classes , macro_averaging = False ):
3040 """
3141 Initializes the F1Score object, setting up the necessary state variables.
3242
@@ -35,28 +45,81 @@ def __init__(self, num_classes):
3545 num_classes : int
3646 The number of classes in the classification task.
3747
48+ macro_averaging : bool, optional, default=False
49+ If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
3850 """
39-
4051 super ().__init__ ()
4152
4253 self .num_classes = num_classes
54+ self .macro_averaging = macro_averaging
4355
44- # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
56+ # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
4557 self .tp = torch .zeros (num_classes )
4658 self .fp = torch .zeros (num_classes )
4759 self .fn = torch .zeros (num_classes )
4860
61+ def _micro_F1 (self ):
62+ """
63+ Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
64+
65+ Micro F1 score is calculated globally by considering all predictions together, regardless of class.
66+
67+ Returns
68+ -------
69+ torch.Tensor
70+ The micro-averaged F1 score.
71+ """
72+ tp = torch .sum (self .tp )
73+ fp = torch .sum (self .fp )
74+ fn = torch .sum (self .fn )
75+
76+ precision = tp / (tp + fp + 1e-8 ) # Avoid division by zero
77+ recall = tp / (tp + fn + 1e-8 ) # Avoid division by zero
78+
79+ f1 = 2 * precision * recall / (precision + recall + 1e-8 ) # Avoid division by zero
80+ return f1
81+
82+ def _macro_F1 (self ):
83+ """
84+ Compute the Macro F1 score by calculating the F1 score per class and averaging.
85+
86+ Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally,
87+ regardless of their frequency.
88+
89+ Returns
90+ -------
91+ torch.Tensor
92+ The macro-averaged F1 score.
93+ """
94+ precision_per_class = self .tp / (self .tp + self .fp + 1e-8 ) # Avoid division by zero
95+ recall_per_class = self .tp / (self .tp + self .fn + 1e-8 ) # Avoid division by zero
96+ f1_per_class = 2 * precision_per_class * recall_per_class / (
97+ precision_per_class + recall_per_class + 1e-8 ) # Avoid division by zero
98+
99+ # Take the average of F1 scores across all classes
100+ f1_score = torch .mean (f1_per_class )
101+ return f1_score
102+
49103 def forward (self , preds , target ):
50104 """
51- Update the variables with predictions and true labels.
105+ Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
106+
107+ This method computes the F1 score based on the predictions and true labels. It can compute either the
108+ macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag.
52109
53110 Parameters
54111 ----------
55112 preds : torch.Tensor
56- Predicted logits (shape: [batch_size, num_classes]).
113+ Predicted logits or class indices (shape: [batch_size, num_classes]).
114+ These logits are typically the output of a softmax or sigmoid activation.
57115
58116 target : torch.Tensor
59- True labels (shape: [batch_size]).
117+ True labels (shape: [batch_size]), where each element is an integer representing the true class.
118+
119+ Returns
120+ -------
121+ torch.Tensor
122+ The computed F1 score (either micro or macro, based on `macro_averaging`).
60123 """
61124 preds = torch .argmax (preds , dim = 1 )
62125
@@ -66,10 +129,11 @@ def forward(self, preds, target):
66129 self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
67130 self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
68131
69- f1_score = (
70- 2
71- * torch .sum (self .tp )
72- / (2 * torch .sum (self .tp ) + torch .sum (self .fp ) + torch .sum (self .fn ))
73- )
132+ if self .macro_averaging :
133+ # Calculate Macro F1 score
134+ f1_score = self ._macro_F1 ()
135+ else :
136+ # Calculate Micro F1 score
137+ f1_score = self ._micro_F1 ()
74138
75- return f1_score
139+ return f1_score
0 commit comments