
Commit bf8a09f

Update recall metric with macro/micro averaging
1 parent 64fac10 commit bf8a09f

File tree

1 file changed (+53, -13 lines)


utils/metrics/recall.py

Lines changed: 53 additions & 13 deletions
@@ -2,12 +2,12 @@
 import torch.nn as nn
 
 
-def one_hot_encode(y_true, num_classes):
+def one_hot_encode(vec, num_classes):
     """One-hot encode the target tensor.
 
     Args
     ----
-    y_true : torch.Tensor
+    vec : torch.Tensor
         Target tensor.
     num_classes : int
         Number of classes in the dataset.
@@ -18,25 +18,65 @@ def one_hot_encode(y_true, num_classes):
         One-hot encoded tensor.
     """
 
-    y_onehot = torch.zeros(y_true.size(0), num_classes)
-    y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
-    return y_onehot
+    onehot = torch.zeros(vec.size(0), num_classes)
+    onehot.scatter_(1, vec.unsqueeze(1), 1)
+    return onehot
 
 
 class Recall(nn.Module):
-    def __init__(self, num_classes):
-        super().__init__()
+    """
+    Recall metric.
+
+    Args
+    ----
+    num_classes : int
+        Number of classes in the dataset.
+    macro_averaging : bool
+        If True, calculate the recall for each class and return the average.
+        If False, calculate the recall for the entire dataset.
 
+    Methods
+    -------
+    forward(y_true, y_pred)
+        Compute the recall metric.
+
+    Examples
+    --------
+    >>> y_true = torch.tensor([0, 1, 2, 3, 4])
+    >>> y_pred = torch.randn(5, 5).argmax(dim=-1)
+    >>> recall = Recall(num_classes=5)
+    >>> recall(y_true, y_pred)
+    0.2
+    >>> recall = Recall(num_classes=5, macro_averaging=True)
+    >>> recall(y_true, y_pred)
+    0.2
+    """
+
+    def __init__(self, num_classes, macro_averaging=False):
+        super().__init__()
         self.num_classes = num_classes
+        self.macro_averaging = macro_averaging
 
-    def forward(self, y_true, y_pred):
-        true_onehot = one_hot_encode(y_true, self.num_classes)
-        pred_onehot = one_hot_encode(y_pred, self.num_classes)
+    def forward(self, true, logits):
+        pred = logits.argmax(dim=-1)
+        y_true = one_hot_encode(true, self.num_classes)
+        y_pred = one_hot_encode(pred, self.num_classes)
 
-        true_positives = (true_onehot * pred_onehot).sum()
+        if self.macro_averaging:
+            recall = 0
+            for i in range(self.num_classes):
+                tp = (y_true[:, i] * y_pred[:, i]).sum()
+                fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
+                recall += tp / (tp + fn)
+            recall /= self.num_classes
+        else:
+            recall = self.__compute(y_true, y_pred)
 
-        false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())
+        return recall
 
-        recall = true_positives / (true_positives + false_negatives)
+    def __compute(self, y_true, y_pred):
+        true_positives = (y_true * y_pred).sum()
+        false_negatives = torch.sum(~y_pred[y_true.bool()].bool())
 
+        recall = true_positives / (true_positives + false_negatives)
         return recall
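
A minimal usage sketch of the metric after this commit. It assumes torch is imported at the top of recall.py (line 1 is outside the hunk) and that the module is importable as utils.metrics.recall from the file path above; the variable names micro_recall and macro_recall are illustrative only. Note that forward() now takes raw logits and applies argmax internally.

import torch

from utils.metrics.recall import Recall  # import path assumed from the file path above

torch.manual_seed(0)

y_true = torch.tensor([0, 1, 2, 3, 4])  # ground-truth class indices, one sample per class
logits = torch.randn(5, 5)              # raw per-class scores; forward() argmaxes these itself

micro_recall = Recall(num_classes=5)                        # micro: pool TP/FN over all samples
macro_recall = Recall(num_classes=5, macro_averaging=True)  # macro: per-class recall, then mean

print(micro_recall(y_true, logits))
print(macro_recall(y_true, logits))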
