Commit 27f120c

Merge pull request #63 from SFI-Visual-Intelligence/Jan-metrics
Merging the metrics updates with dataloader updates
2 parents: b7bffa3 + 4071181

5 files changed: +161 additions, -41 deletions

main.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def main():
         val_size=args.val_size,
     )
 
-    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes, macro_averaging=args.macro_averaging)
 
     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape
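Note: with this one-line change, the macro_averaging flag parsed in utils/arg_parser.py (next file) flows through MetricWrapper down to each individual metric.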

utils/arg_parser.py

Lines changed: 7 additions & 1 deletion
@@ -68,6 +68,12 @@ def get_args():
         nargs="+",
         help="Which metric to use for evaluation",
     )
+    parser.add_argument(
+        "--macro_averaging",
+        action="store_true",
+        help="If the flag is included, the metrics will be calculated using macro averaging.",
+    )
+
 
     parser.add_argument("--imagesize", type=int, default=28, help="Imagesize")
 
@@ -108,7 +114,7 @@ def get_args():
     parser.add_argument(
         "--dry_run",
         action="store_true",
-        help="If true, the code will not run the training loop.",
+        help="If the flag is included, the code will not run the training loop.",
     )
     args = parser.parse_args()
 
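As a usage sketch (the metric names are those accepted by MetricWrapper in the next file), a run enabling the new behaviour would look like:

    python main.py --metric f1 precision --macro_averaging

Since the flag uses action="store_true", omitting it keeps the default micro averaging.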

utils/load_metric.py

Lines changed: 6 additions & 5 deletions
@@ -45,10 +45,11 @@ class MetricWrapper(nn.Module):
     {'entropy': [], 'f1': [], 'precision': []}
     """
 
-    def __init__(self, *metrics, num_classes):
+    def __init__(self, *metrics, num_classes, macro_averaging=False):
         super().__init__()
         self.metrics = {}
         self.num_classes = num_classes
+        self.macro_averaging = macro_averaging
 
         for metric in metrics:
             self.metrics[metric] = self._get_metric(metric)
@@ -74,13 +75,13 @@ def _get_metric(self, key):
             case "entropy":
                 return EntropyPrediction(num_classes=self.num_classes)
             case "f1":
-                return F1Score(num_classes=self.num_classes)
+                return F1Score(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "recall":
-                return Recall(num_classes=self.num_classes)
+                return Recall(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "precision":
-                return Precision(num_classes=self.num_classes)
+                return Precision(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "accuracy":
-                return Accuracy(num_classes=self.num_classes)
+                return Accuracy(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case _:
                 raise ValueError(f"Metric {key} not supported")
 
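A minimal construction sketch mirroring the updated call site in main.py (the import path and num_classes value here are illustrative, not from the diff):

    from utils.load_metric import MetricWrapper

    # Builds an F1Score and an Accuracy instance, both with macro averaging enabled;
    # EntropyPrediction, by contrast, still receives only num_classes.
    metrics = MetricWrapper("f1", "accuracy", num_classes=10, macro_averaging=True)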

utils/metrics/F1.py

Lines changed: 83 additions & 30 deletions
@@ -4,29 +4,39 @@
 
 class F1Score(nn.Module):
     """
-    F1 Score implementation with direct averaging inside the compute method.
+    F1 Score implementation with support for both macro and micro averaging.
+
+    This class computes the F1 score during training using either macro or micro averaging.
+    The F1 score is calculated based on the true positives (TP), false positives (FP),
+    and false negatives (FN) for each class.
 
     Parameters
     ----------
     num_classes : int
-        Number of classes.
+        The number of classes in the classification task.
+
+    macro_averaging : bool, optional, default=False
+        If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
 
     Attributes
     ----------
     num_classes : int
-        The number of classes.
+        The number of classes in the classification task.
 
     tp : torch.Tensor
-        Tensor for True Positives (TP) for each class.
+        Tensor storing the count of True Positives (TP) for each class.
 
     fp : torch.Tensor
-        Tensor for False Positives (FP) for each class.
+        Tensor storing the count of False Positives (FP) for each class.
 
     fn : torch.Tensor
-        Tensor for False Negatives (FN) for each class.
+        Tensor storing the count of False Negatives (FN) for each class.
+
+    macro_averaging : bool
+        A flag indicating whether to compute the macro-averaged F1 score or not.
     """
 
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, macro_averaging=False):
        """
        Initializes the F1Score object, setting up the necessary state variables.
 
@@ -35,28 +45,81 @@ def __init__(self, num_classes):
        num_classes : int
            The number of classes in the classification task.
 
+        macro_averaging : bool, optional, default=False
+            If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
        """
-
        super().__init__()
 
        self.num_classes = num_classes
+        self.macro_averaging = macro_averaging
 
-        # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
+        # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
        self.tp = torch.zeros(num_classes)
        self.fp = torch.zeros(num_classes)
        self.fn = torch.zeros(num_classes)
 
-    def update(self, preds, target):
+    def _micro_F1(self):
+        """
+        Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
+
+        Micro F1 score is calculated globally by considering all predictions together, regardless of class.
+
+        Returns
+        -------
+        torch.Tensor
+            The micro-averaged F1 score.
        """
-        Update the variables with predictions and true labels.
+        tp = torch.sum(self.tp)
+        fp = torch.sum(self.fp)
+        fn = torch.sum(self.fn)
+
+        precision = tp / (tp + fp + 1e-8)  # Avoid division by zero
+        recall = tp / (tp + fn + 1e-8)  # Avoid division by zero
+
+        f1 = 2 * precision * recall / (precision + recall + 1e-8)  # Avoid division by zero
+        return f1
+
+    def _macro_F1(self):
+        """
+        Compute the Macro F1 score by calculating the F1 score per class and averaging.
+
+        Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally,
+        regardless of their frequency.
+
+        Returns
+        -------
+        torch.Tensor
+            The macro-averaged F1 score.
+        """
+        precision_per_class = self.tp / (self.tp + self.fp + 1e-8)  # Avoid division by zero
+        recall_per_class = self.tp / (self.tp + self.fn + 1e-8)  # Avoid division by zero
+        f1_per_class = 2 * precision_per_class * recall_per_class / (
+            precision_per_class + recall_per_class + 1e-8)  # Avoid division by zero
+
+        # Take the average of F1 scores across all classes
+        f1_score = torch.mean(f1_per_class)
+        return f1_score
+
+    def forward(self, preds, target):
+        """
+        Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
+
+        This method computes the F1 score based on the predictions and true labels. It can compute either the
+        macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag.
 
        Parameters
        ----------
        preds : torch.Tensor
-            Predicted logits (shape: [batch_size, num_classes]).
+            Predicted logits or class indices (shape: [batch_size, num_classes]).
+            These logits are typically the output of a softmax or sigmoid activation.
 
        target : torch.Tensor
-            True labels (shape: [batch_size]).
+            True labels (shape: [batch_size]), where each element is an integer representing the true class.
+
+        Returns
+        -------
+        torch.Tensor
+            The computed F1 score (either micro or macro, based on `macro_averaging`).
        """
        preds = torch.argmax(preds, dim=1)
 
@@ -66,21 +129,11 @@ def update(self, preds, target):
            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
 
-    def compute(self):
-        """
-        Compute the F1 score.
+        if self.macro_averaging:
+            # Calculate Macro F1 score
+            f1_score = self._macro_F1()
+        else:
+            # Calculate Micro F1 score
+            f1_score = self._micro_F1()
 
-        Returns
-        -------
-        torch.Tensor
-            The computed F1 score.
-        """
-
-        # Compute F1 score based on the specified averaging method
-        f1_score = (
-            2
-            * torch.sum(self.tp)
-            / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
-        )
-
-        return f1_score
+        return f1_score
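A small usage sketch of the rewritten class (the import path follows the repo layout; the tensors are illustrative):

    import torch
    from utils.metrics.F1 import F1Score

    logits = torch.tensor([[2.0, 0.1, 0.1],   # argmax -> class 0 (correct)
                           [0.1, 2.0, 0.1],   # argmax -> class 1 (correct)
                           [0.1, 2.0, 0.1]])  # argmax -> class 1 (true class is 2)
    target = torch.tensor([0, 1, 2])

    micro = F1Score(num_classes=3)                        # micro averaging (default)
    macro = F1Score(num_classes=3, macro_averaging=True)  # macro averaging

    print(micro(logits, target))  # ~0.667: pooled TP/FP/FN give precision = recall = 2/3
    print(macro(logits, target))  # ~0.556: mean of per-class F1 scores (1.0, 0.667, 0.0)

Note that forward applies torch.argmax along dim=1 unconditionally, so despite the docstring mentioning class indices, preds must be 2-D scores; and since tp/fp/fn accumulate across calls, each instance keeps running totals.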

utils/metrics/accuracy.py

Lines changed: 64 additions & 4 deletions
@@ -3,10 +3,11 @@
 
 
 class Accuracy(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, macro_averaging=False):
        super().__init__()
        self.num_classes = num_classes
-
+        self.macro_averaging = macro_averaging
+
    def forward(self, y_true, y_pred):
        """
        Compute the accuracy of the model.
@@ -23,12 +24,71 @@ def forward(self, y_true, y_pred):
        float
            Accuracy score.
        """
+        if y_pred.dim() > 1:
+            y_pred = y_pred.argmax(dim=1)
+        if self.macro_averaging:
+            return self._macro_acc(y_true, y_pred)
+        else:
+            return self._micro_acc(y_true, y_pred)
+
+    def _macro_acc(self, y_true, y_pred):
+        """
+        Compute the macro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Macro-average accuracy score.
+        """
+        y_true, y_pred = y_true.flatten(), y_pred.flatten()  # Ensure 1D shape
+
+        classes = torch.unique(y_true)  # Find unique class labels
+        acc_per_class = []
+
+        for c in classes:
+            mask = (y_true == c)  # Mask for class c
+            acc = (y_pred[mask] == y_true[mask]).float().mean()  # Accuracy for class c
+            acc_per_class.append(acc)
+
+        macro_acc = torch.stack(acc_per_class).mean().item()  # Average across classes
+        return macro_acc
+
+    def _micro_acc(self, y_true, y_pred):
+        """
+        Compute the micro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Micro-average accuracy score.
+        """
        return (y_true == y_pred).float().mean().item()
 
 
 if __name__ == "__main__":
+    accuracy = Accuracy(5)
+    macro_accuracy = Accuracy(5, macro_averaging=True)
+
    y_true = torch.tensor([0, 3, 2, 3, 4])
    y_pred = torch.tensor([0, 1, 2, 3, 4])
-
-    accuracy = Accuracy()
    print(accuracy(y_true, y_pred))
+    print(macro_accuracy(y_true, y_pred))
+
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_onehot_pred = torch.tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
+    print(accuracy(y_true, y_onehot_pred))
+    print(macro_accuracy(y_true, y_onehot_pred))
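For the demo above: micro accuracy is 4/5 = 0.8 (four of five predictions match), while macro accuracy averages per-class accuracy over the classes present in y_true ({0, 2, 3, 4}), giving (1.0 + 1.0 + 0.5 + 1.0) / 4 = 0.875. The one-hot block should print the same pair of values, since the argmax of each one-hot row recovers the identical predicted labels.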
