Skip to content

Commit 9ec9d7b

Browse files
committed
added missing attribute retunrmetric to F1
1 parent c29b585 commit 9ec9d7b

File tree

1 file changed

+35
-16
lines changed
  • CollaborativeCoding/metrics

1 file changed

+35
-16
lines changed

CollaborativeCoding/metrics/F1.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
34

@@ -52,13 +53,14 @@ def __init__(self, num_classes, macro_averaging=False):
5253

5354
self.num_classes = num_classes
5455
self.macro_averaging = macro_averaging
55-
56+
self.y_true = []
57+
self.y_pred = []
5658
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
5759
self.tp = torch.zeros(num_classes)
5860
self.fp = torch.zeros(num_classes)
5961
self.fn = torch.zeros(num_classes)
6062

61-
def _micro_F1(self):
63+
def _micro_F1(self, target, preds):
6264
"""
6365
Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
6466
@@ -69,6 +71,11 @@ def _micro_F1(self):
6971
torch.Tensor
7072
The micro-averaged F1 score.
7173
"""
74+
for i in range(self.num_classes):
75+
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
76+
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
77+
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
78+
7279
tp = torch.sum(self.tp)
7380
fp = torch.sum(self.fp)
7481
fn = torch.sum(self.fn)
@@ -81,7 +88,7 @@ def _micro_F1(self):
8188
) # Avoid division by zero
8289
return f1
8390

84-
def _macro_F1(self):
91+
def _macro_F1(self, target, preds):
8592
"""
8693
Compute the Macro F1 score by calculating the F1 score per class and averaging.
8794
@@ -93,6 +100,12 @@ def _macro_F1(self):
93100
torch.Tensor
94101
The macro-averaged F1 score.
95102
"""
103+
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
104+
for i in range(self.num_classes):
105+
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
106+
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
107+
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
108+
96109
precision_per_class = self.tp / (
97110
self.tp + self.fp + 1e-8
98111
) # Avoid division by zero
@@ -133,18 +146,24 @@ def forward(self, preds, target):
133146
The computed F1 score (either micro or macro, based on `macro_averaging`).
134147
"""
135148
preds = torch.argmax(preds, dim=-1)
149+
self.y_true.append(target)
150+
self.y_pred.append(preds)
151+
152+
def __returnmetric__(self):
153+
if self.y_true == [] or self.y_pred == []:
154+
return np.nan
155+
if isinstance(self.y_true, list):
156+
if len(self.y_true) == 1:
157+
self.y_true = self.y_true[0]
158+
self.y_pred = self.y_pred[0]
159+
else:
160+
self.y_true = torch.cat(self.y_true)
161+
self.y_pred = torch.cat(self.y_pred)
162+
return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
163+
164+
def __reset__(self):
165+
self.y_true = []
166+
self.y_pred = []
167+
return None
136168

137-
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
138-
for i in range(self.num_classes):
139-
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
140-
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
141-
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
142169

143-
if self.macro_averaging:
144-
# Calculate Macro F1 score
145-
f1_score = self._macro_F1()
146-
else:
147-
# Calculate Micro F1 score
148-
f1_score = self._micro_F1()
149-
150-
return f1_score

0 commit comments

Comments
 (0)