Skip to content

Commit 26ff680

Browse files
committed
added micro/macro averaging option to F1
1 parent 234b7f6 commit 26ff680

File tree

1 file changed

+82
-18
lines changed

1 file changed

+82
-18
lines changed

utils/metrics/F1.py

Lines changed: 82 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,39 @@
44

55
class F1Score(nn.Module):
66
"""
7-
F1 Score implementation with direct averaging inside the compute method.
7+
F1 Score implementation with support for both macro and micro averaging.
8+
9+
This class computes the F1 score during training using either macro or micro averaging.
10+
The F1 score is calculated based on the true positives (TP), false positives (FP),
11+
and false negatives (FN) for each class.
812
913
Parameters
1014
----------
1115
num_classes : int
12-
Number of classes.
16+
The number of classes in the classification task.
17+
18+
macro_averaging : bool, optional, default=False
19+
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
1320
1421
Attributes
1522
----------
1623
num_classes : int
17-
The number of classes.
24+
The number of classes in the classification task.
1825
1926
tp : torch.Tensor
20-
Tensor for True Positives (TP) for each class.
27+
Tensor storing the count of True Positives (TP) for each class.
2128
2229
fp : torch.Tensor
23-
Tensor for False Positives (FP) for each class.
30+
Tensor storing the count of False Positives (FP) for each class.
2431
2532
fn : torch.Tensor
26-
Tensor for False Negatives (FN) for each class.
33+
Tensor storing the count of False Negatives (FN) for each class.
34+
35+
macro_averaging : bool
36+
A flag indicating whether to compute the macro-averaged F1 score or not.
2737
"""
2838

29-
def __init__(self, num_classes):
39+
def __init__(self, num_classes, macro_averaging=False):
3040
"""
3141
Initializes the F1Score object, setting up the necessary state variables.
3242
@@ -35,28 +45,81 @@ def __init__(self, num_classes):
3545
num_classes : int
3646
The number of classes in the classification task.
3747
48+
macro_averaging : bool, optional, default=False
49+
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
3850
"""
39-
4051
super().__init__()
4152

4253
self.num_classes = num_classes
54+
self.macro_averaging = macro_averaging
4355

44-
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
56+
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
4557
self.tp = torch.zeros(num_classes)
4658
self.fp = torch.zeros(num_classes)
4759
self.fn = torch.zeros(num_classes)
4860

61+
def _micro_F1(self):
62+
"""
63+
Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
64+
65+
Micro F1 score is calculated globally by considering all predictions together, regardless of class.
66+
67+
Returns
68+
-------
69+
torch.Tensor
70+
The micro-averaged F1 score.
71+
"""
72+
tp = torch.sum(self.tp)
73+
fp = torch.sum(self.fp)
74+
fn = torch.sum(self.fn)
75+
76+
precision = tp / (tp + fp + 1e-8) # Avoid division by zero
77+
recall = tp / (tp + fn + 1e-8) # Avoid division by zero
78+
79+
f1 = 2 * precision * recall / (precision + recall + 1e-8) # Avoid division by zero
80+
return f1
81+
82+
def _macro_F1(self):
83+
"""
84+
Compute the Macro F1 score by calculating the F1 score per class and averaging.
85+
86+
Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally,
87+
regardless of their frequency.
88+
89+
Returns
90+
-------
91+
torch.Tensor
92+
The macro-averaged F1 score.
93+
"""
94+
precision_per_class = self.tp / (self.tp + self.fp + 1e-8) # Avoid division by zero
95+
recall_per_class = self.tp / (self.tp + self.fn + 1e-8) # Avoid division by zero
96+
f1_per_class = 2 * precision_per_class * recall_per_class / (
97+
precision_per_class + recall_per_class + 1e-8) # Avoid division by zero
98+
99+
# Take the average of F1 scores across all classes
100+
f1_score = torch.mean(f1_per_class)
101+
return f1_score
102+
49103
def forward(self, preds, target):
50104
"""
51-
Update the variables with predictions and true labels.
105+
Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
106+
107+
This method computes the F1 score based on the predictions and true labels. It can compute either the
108+
macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag.
52109
53110
Parameters
54111
----------
55112
preds : torch.Tensor
56-
Predicted logits (shape: [batch_size, num_classes]).
113+
Predicted logits or class indices (shape: [batch_size, num_classes]).
114+
These logits are typically the output of a softmax or sigmoid activation.
57115
58116
target : torch.Tensor
59-
True labels (shape: [batch_size]).
117+
True labels (shape: [batch_size]), where each element is an integer representing the true class.
118+
119+
Returns
120+
-------
121+
torch.Tensor
122+
The computed F1 score (either micro or macro, based on `macro_averaging`).
60123
"""
61124
preds = torch.argmax(preds, dim=1)
62125

@@ -66,10 +129,11 @@ def forward(self, preds, target):
66129
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
67130
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
68131

69-
f1_score = (
70-
2
71-
* torch.sum(self.tp)
72-
/ (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
73-
)
132+
if self.macro_averaging:
133+
# Calculate Macro F1 score
134+
f1_score = self._macro_F1()
135+
else:
136+
# Calculate Micro F1 score
137+
f1_score = self._micro_F1()
74138

75-
return f1_score
139+
return f1_score

0 commit comments

Comments
 (0)