
Commit c2c72ee

Merge branch 'main' into christian/update-dataloader-recall

2 parents: 0de568e + 331734d

File tree: 2 files changed (+59, −25 lines)

CollaborativeCoding/metrics/F1.py

Lines changed: 35 additions & 16 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -52,13 +53,14 @@ def __init__(self, num_classes, macro_averaging=False):
 
         self.num_classes = num_classes
         self.macro_averaging = macro_averaging
-
+        self.y_true = []
+        self.y_pred = []
         # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
         self.tp = torch.zeros(num_classes)
         self.fp = torch.zeros(num_classes)
         self.fn = torch.zeros(num_classes)
 
-    def _micro_F1(self):
+    def _micro_F1(self, target, preds):
         """
         Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
 
@@ -69,6 +71,11 @@ def _micro_F1(self):
         torch.Tensor
             The micro-averaged F1 score.
         """
+        for i in range(self.num_classes):
+            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
+            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
+            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
+
         tp = torch.sum(self.tp)
         fp = torch.sum(self.fp)
         fn = torch.sum(self.fn)
@@ -81,7 +88,7 @@ def _micro_F1(self):
         )  # Avoid division by zero
         return f1
 
-    def _macro_F1(self):
+    def _macro_F1(self, target, preds):
         """
         Compute the Macro F1 score by calculating the F1 score per class and averaging.
 
@@ -93,6 +100,12 @@ def _macro_F1(self):
         torch.Tensor
             The macro-averaged F1 score.
         """
+        # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
+        for i in range(self.num_classes):
+            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
+            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
+            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
+
         precision_per_class = self.tp / (
             self.tp + self.fp + 1e-8
         )  # Avoid division by zero
@@ -133,18 +146,24 @@ def forward(self, preds, target):
             The computed F1 score (either micro or macro, based on `macro_averaging`).
         """
         preds = torch.argmax(preds, dim=-1)
+        self.y_true.append(target)
+        self.y_pred.append(preds)
+
+    def __returnmetric__(self):
+        if self.y_true == [] or self.y_pred == []:
+            return np.nan
+        if isinstance(self.y_true, list):
+            if len(self.y_true) == 1:
+                self.y_true = self.y_true[0]
+                self.y_pred = self.y_pred[0]
+            else:
+                self.y_true = torch.cat(self.y_true)
+                self.y_pred = torch.cat(self.y_pred)
+        return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
+
+    def __reset__(self):
+        self.y_true = []
+        self.y_pred = []
+        return None
 
-        # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
-        for i in range(self.num_classes):
-            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
-            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
-            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
-
-        if self.macro_averaging:
-            # Calculate Macro F1 score
-            f1_score = self._macro_F1()
-        else:
-            # Calculate Micro F1 score
-            f1_score = self._micro_F1()
-
-        return f1_score
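
The net effect of this diff is that F1Score moves from computing the score inside forward to an accumulate-then-compute pattern: forward now only argmaxes the logits and buffers each batch, __returnmetric__ concatenates the buffers and dispatches to _micro_F1 or _macro_F1, and __reset__ clears the buffers. A minimal usage sketch (not part of the commit; the import path is inferred from the file location above, and the batch data is made up):

import torch

from CollaborativeCoding.metrics.F1 import F1Score

f1 = F1Score(num_classes=3, macro_averaging=True)

# Feed a few hypothetical batches of logits (shape: batch x num_classes)
# and integer labels; forward() argmaxes the logits and buffers the batch.
for _ in range(2):
    logits = torch.randn(8, 3)
    labels = torch.randint(0, 3, (8,))
    f1(logits, labels)

score = f1.__returnmetric__()  # concatenates the buffers, then computes macro F1
f1.__reset__()                 # restores the empty list buffers before the next epoch
print(score)

Note that __returnmetric__ replaces the list buffers with concatenated tensors in place, so __reset__ is required before the metric object is reused.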

tests/test_metrics.py

Lines changed: 24 additions & 9 deletions
@@ -74,17 +74,32 @@ def test_recall():
 def test_f1score():
     import torch
 
-    f1_metric = F1Score(num_classes=3)
-    preds = torch.tensor(
-        [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
-    )
+    # Example case with known output
+    y_true = torch.tensor([0, 1, 2, 2, 1, 0])  # True labels
+    y_pred = torch.tensor([0, 1, 1, 2, 0, 0])  # Predicted labels
+
+    # Create F1Score object for micro and macro averaging
+    f1_micro = F1Score(num_classes=3, macro_averaging=False)
+    f1_macro = F1Score(num_classes=3, macro_averaging=True)
+
+    # Update F1 score with predictions
+    f1_micro(y_true, y_pred)
+    f1_macro(y_true, y_pred)
+
+    # Get F1 scores
+    micro_f1_score = f1_micro.__returnmetric__()
+    macro_f1_score = f1_macro.__returnmetric__()
+
+    # Check if outputs are tensors
+    assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
+    assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."
 
-    target = torch.tensor([0, 1, 0, 2])
+    # Check that F1 scores are between 0 and 1
+    assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
+    assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."
 
-    f1_metric(preds, target)
-    assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
-    assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
-    assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
+    print(f"Micro F1 Score: {micro_f1_score.item()}")
+    print(f"Macro F1 Score: {macro_f1_score.item()}")
 
 
 def test_precision():
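
For reference, the label/prediction pair in the new test has easily hand-checked scores: per class (0, 1, 2) the counts are TP = (2, 1, 1), FP = (1, 1, 0), FN = (0, 1, 1), so micro F1 = 2·4 / (2·4 + 2 + 2) = 2/3, and macro F1 = (0.8 + 0.5 + 2/3) / 3 ≈ 0.656. A quick cross-check of those numbers (not part of the commit, assuming scikit-learn is available):

from sklearn.metrics import f1_score

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 1, 1, 2, 0, 0]

print(f1_score(y_true, y_pred, average="micro"))  # 0.666... = 2/3
print(f1_score(y_true, y_pred, average="macro"))  # 0.655... (per-class F1: 0.8, 0.5, 0.666...)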
