Skip to content

Commit cd1e086

Browse files
committed
Add Recall metric
1 parent 3f79234 commit cd1e086

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["EntropyPrediction"]
1+
__all__ = ["EntropyPrediction", "Recall"]
22

33
from .EntropyPred import EntropyPrediction
4+
from .recall import Recall

utils/metrics/recall.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
def one_hot_encode(y_true, num_classes):
6+
"""One-hot encode the target tensor.
7+
8+
Args
9+
----
10+
y_true : torch.Tensor
11+
Target tensor.
12+
num_classes : int
13+
Number of classes in the dataset.
14+
15+
Returns
16+
-------
17+
torch.Tensor
18+
One-hot encoded tensor.
19+
"""
20+
21+
y_onehot = torch.zeros(y_true.size(0), num_classes)
22+
y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
23+
return y_onehot
24+
25+
26+
class Recall(nn.Module):
27+
def __init__(self, num_classes):
28+
super().__init__()
29+
30+
self.num_classes = num_classes
31+
32+
def forward(self, y_true, y_pred):
33+
true_onehot = one_hot_encode(y_true, self.num_classes)
34+
pred_onehot = one_hot_encode(y_pred, self.num_classes)
35+
36+
true_positives = (true_onehot * pred_onehot).sum()
37+
38+
false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())
39+
40+
recall = true_positives / (true_positives + false_negatives)
41+
42+
return recall
43+
44+
45+
def test_recall():
46+
recall = Recall(7)
47+
48+
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
49+
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
50+
51+
recall_score = recall(y_true, y_pred)
52+
53+
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"
54+
55+
56+
def test_one_hot_encode():
57+
num_classes = 7
58+
59+
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
60+
y_onehot = one_hot_encode(y_true, num_classes)
61+
62+
assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"

0 commit comments

Comments
 (0)