
Commit 54f3883

Added micro/macro averaging option to MetricWrapper and as a command-line argument
1 parent b93ee66

File tree: 4 files changed (+78, -11 lines)

4 files changed

+78
-11
lines changed

main.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def main():
         transform=augmentations,
     )
 
-    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes, macro_averaging=args.macro_averaging)
 
     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape

utils/arg_parser.py

Lines changed: 7 additions & 1 deletion

@@ -63,6 +63,12 @@ def get_args():
         nargs="+",
         help="Which metric to use for evaluation",
     )
+    parser.add_argument(
+        "--macro_averaging",
+        action="store_true",
+        help="If the flag is included, the metrics will be calculated using macro averaging.",
+    )
+
 
     # Training specific values
     parser.add_argument(
@@ -93,6 +99,6 @@ def get_args():
     parser.add_argument(
         "--dry_run",
         action="store_true",
-        help="If true, the code will not run the training loop.",
+        help="If the flag is included, the code will not run the training loop.",
     )
     return parser.parse_args()

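Since --macro_averaging is defined with action="store_true", omitting the flag leaves it False and preserves the previous micro-averaged behaviour. A hypothetical invocation (any other required arguments depend on the rest of arg_parser.py and are omitted here):

    python main.py --metric f1 recall accuracy --macro_averaging
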
utils/load_metric.py

Lines changed: 6 additions & 5 deletions

@@ -47,11 +47,12 @@ class MetricWrapper(nn.Module):
     """
 
 
-    def __init__(self, *metrics, num_classes):
+    def __init__(self, *metrics, num_classes, macro_averaging=False):
 
         super().__init__()
         self.metrics = {}
         self.num_classes = num_classes
+        self.macro_averaging = macro_averaging
 
         for metric in metrics:
             self.metrics[metric] = self._get_metric(metric)
@@ -77,13 +78,13 @@ def _get_metric(self, key):
             case "entropy":
                 return EntropyPrediction(num_classes=self.num_classes)
             case "f1":
-                return F1Score(num_classes=self.num_classes)
+                return F1Score(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "recall":
-                return Recall(num_classes=self.num_classes)
+                return Recall(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "precision":
-                return Precision(num_classes=self.num_classes)
+                return Precision(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "accuracy":
-                return Accuracy(num_classes=self.num_classes)
+                return Accuracy(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case _:
                 raise ValueError(f"Metric {key} not supported")

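For reference, a minimal sketch of constructing the wrapper directly in Python; the import path is assumed from the file location (utils/load_metric.py) and may differ depending on how the package is laid out:

    from utils.load_metric import MetricWrapper

    # Default is micro averaging; macro_averaging=True propagates to every
    # metric except "entropy", whose constructor does not take the argument.
    micro = MetricWrapper("f1", "accuracy", num_classes=5)
    macro = MetricWrapper("f1", "accuracy", num_classes=5, macro_averaging=True)
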
utils/metrics/accuracy.py

Lines changed: 64 additions & 4 deletions

@@ -3,10 +3,11 @@
 
 
 class Accuracy(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, macro_averaging=False):
         super().__init__()
         self.num_classes = num_classes
-
+        self.macro_averaging = macro_averaging
+
     def forward(self, y_true, y_pred):
         """
         Compute the accuracy of the model.
@@ -23,12 +24,71 @@ def forward(self, y_true, y_pred):
         float
             Accuracy score.
         """
+        if y_pred.dim() > 1:
+            y_pred = y_pred.argmax(dim=1)
+        if self.macro_averaging:
+            return self._macro_acc(y_true, y_pred)
+        else:
+            return self._micro_acc(y_true, y_pred)
+
+    def _macro_acc(self, y_true, y_pred):
+        """
+        Compute the macro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Macro-average accuracy score.
+        """
+        y_true, y_pred = y_true.flatten(), y_pred.flatten()  # Ensure 1D shape
+
+        classes = torch.unique(y_true)  # Find unique class labels
+        acc_per_class = []
+
+        for c in classes:
+            mask = (y_true == c)  # Mask for class c
+            acc = (y_pred[mask] == y_true[mask]).float().mean()  # Accuracy for class c
+            acc_per_class.append(acc)
+
+        macro_acc = torch.stack(acc_per_class).mean().item()  # Average across classes
+        return macro_acc
+
+    def _micro_acc(self, y_true, y_pred):
+        """
+        Compute the micro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Micro-average accuracy score.
+        """
         return (y_true == y_pred).float().mean().item()
 
 
 if __name__ == "__main__":
+    accuracy = Accuracy(5)
+    macro_accuracy = Accuracy(5, macro_averaging=True)
+
     y_true = torch.tensor([0, 3, 2, 3, 4])
     y_pred = torch.tensor([0, 1, 2, 3, 4])
-
-    accuracy = Accuracy()
     print(accuracy(y_true, y_pred))
+    print(macro_accuracy(y_true, y_pred))
+
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_onehot_pred = torch.tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
+    print(accuracy(y_true, y_onehot_pred))
+    print(macro_accuracy(y_true, y_onehot_pred))

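A worked check of the demo above: with y_true = [0, 3, 2, 3, 4] and y_pred = [0, 1, 2, 3, 4], micro accuracy is 4/5 = 0.8, while macro accuracy averages the per-class accuracies over the classes present in y_true: class 0 → 1/1, class 2 → 1/1, class 3 → 1/2, class 4 → 1/1, giving (1.0 + 1.0 + 0.5 + 1.0) / 4 = 0.875. The one-hot block argmaxes to the same predictions, so it prints the same pair of numbers. Note that _macro_acc takes its class set from torch.unique(y_true), so classes that never occur in y_true are excluded from the average.

An equivalent vectorized formulation, shown as a sketch rather than the committed code:

    import torch

    def macro_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int) -> float:
        # Count correct predictions and total samples per class, then average
        # the per-class accuracies over classes that appear in y_true,
        # mirroring the torch.unique behaviour above.
        correct = torch.bincount(y_true[y_true == y_pred], minlength=num_classes).float()
        support = torch.bincount(y_true, minlength=num_classes).float()
        present = support > 0
        return (correct[present] / support[present]).mean().item()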