Skip to content

Commit 2be6ccf

Browse files
committed
Update recall to store metrics on the go
1 parent eab0b08 commit 2be6ccf

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

CollaborativeCoding/metrics/recall.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
34

@@ -57,26 +58,49 @@ def __init__(self, num_classes, macro_averaging=False):
5758
self.num_classes = num_classes
5859
self.macro_averaging = macro_averaging
5960

61+
self.__y_true = []
62+
self.__y_pred = []
63+
6064
def forward(self, true, logits):
6165
pred = logits.argmax(dim=-1)
6266
y_true = one_hot_encode(true, self.num_classes)
6367
y_pred = one_hot_encode(pred, self.num_classes)
6468

69+
self.__y_true.append(y_true)
70+
self.__y_pred.append(y_pred)
71+
72+
def compute(self, y_true, y_pred):
6573
if self.macro_averaging:
66-
recall = 0
67-
for i in range(self.num_classes):
68-
tp = (y_true[:, i] * y_pred[:, i]).sum()
69-
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
70-
recall += tp / (tp + fn)
71-
recall /= self.num_classes
72-
else:
73-
recall = self.__compute(y_true, y_pred)
74+
return self.__compute_macro_averaging(y_true, y_pred)
75+
76+
return self.__compute_micro_averaging(y_true, y_pred)
77+
78+
def __compute_macro_averaging(self, y_true, y_pred):
79+
recall = 0
80+
for i in range(self.num_classes):
81+
tp = (y_true[:, i] * y_pred[:, i]).sum()
82+
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
83+
recall += tp / (tp + fn)
84+
recall /= self.num_classes
7485

7586
return recall
7687

77-
def __compute(self, y_true, y_pred):
88+
def __compute_micro_averaging(self, y_true, y_pred):
7889
true_positives = (y_true * y_pred).sum()
7990
false_negatives = torch.sum(~y_pred[y_true.bool()].bool())
8091

8192
recall = true_positives / (true_positives + false_negatives)
8293
return recall
94+
95+
def __returnmetric__(self):
96+
if len(self.__y_true) == 0 and len(self.__y_pred) == 0:
97+
return np.nan
98+
99+
y_true = torch.cat(self.__y_true, dim=0)
100+
y_pred = torch.cat(self.__y_pred, dim=0)
101+
102+
return self.compute(y_true, y_pred)
103+
104+
def __reset__(self):
105+
self.__y_true = []
106+
self.__y_pred = []

0 commit comments

Comments
 (0)