22import torch .nn as nn
33
44
5- def one_hot_encode (y_true , num_classes ):
5+ def one_hot_encode (vec , num_classes ):
66 """One-hot encode the target tensor.
77
88 Args
99 ----
10- y_true : torch.Tensor
10+ vec : torch.Tensor
1111 Target tensor.
1212 num_classes : int
1313 Number of classes in the dataset.
@@ -18,25 +18,65 @@ def one_hot_encode(y_true, num_classes):
1818 One-hot encoded tensor.
1919 """
2020
21- y_onehot = torch .zeros (y_true .size (0 ), num_classes )
22- y_onehot .scatter_ (1 , y_true .unsqueeze (1 ), 1 )
23- return y_onehot
21+ onehot = torch .zeros (vec .size (0 ), num_classes )
22+ onehot .scatter_ (1 , vec .unsqueeze (1 ), 1 )
23+ return onehot
2424
2525
2626class Recall (nn .Module ):
27- def __init__ (self , num_classes ):
28- super ().__init__ ()
27+ """
28+ Recall metric.
29+
30+ Args
31+ ----
32+ num_classes : int
33+ Number of classes in the dataset.
34+ macro_averaging : bool
35+ If True, calculate the recall for each class and return the average.
36+ If False, calculate the recall for the entire dataset.
2937
38+ Methods
39+ -------
40+ forward(y_true, y_pred)
41+ Compute the recall metric.
42+
43+ Examples
44+ --------
45+ >>> y_true = torch.tensor([0, 1, 2, 3, 4])
46+ >>> y_pred = torch.randn(5, 5).argmax(dim=-1)
47+ >>> recall = Recall(num_classes=5)
48+ >>> recall(y_true, y_pred)
49+ 0.2
50+ >>> recall = Recall(num_classes=5, macro_averaging=True)
51+ >>> recall(y_true, y_pred)
52+ 0.2
53+ """
54+
55+ def __init__ (self , num_classes , macro_averaging = False ):
56+ super ().__init__ ()
3057 self .num_classes = num_classes
58+ self .macro_averaging = macro_averaging
3159
32- def forward (self , y_true , y_pred ):
33- true_onehot = one_hot_encode (y_true , self .num_classes )
34- pred_onehot = one_hot_encode (y_pred , self .num_classes )
60+ def forward (self , true , logits ):
61+ pred = logits .argmax (dim = - 1 )
62+ y_true = one_hot_encode (true , self .num_classes )
63+ y_pred = one_hot_encode (pred , self .num_classes )
3564
36- true_positives = (true_onehot * pred_onehot ).sum ()
65+ if self .macro_averaging :
66+ recall = 0
67+ for i in range (self .num_classes ):
68+ tp = (y_true [:, i ] * y_pred [:, i ]).sum ()
69+ fn = torch .sum (~ y_pred [y_true [:, i ].bool ()].bool ())
70+ recall += tp / (tp + fn )
71+ recall /= self .num_classes
72+ else :
73+ recall = self .__compute (y_true , y_pred )
3774
38- false_negatives = torch . sum ( ~ pred_onehot [ true_onehot . bool ()]. bool ())
75+ return recall
3976
40- recall = true_positives / (true_positives + false_negatives )
77+ def __compute (self , y_true , y_pred ):
78+ true_positives = (y_true * y_pred ).sum ()
79+ false_negatives = torch .sum (~ y_pred [y_true .bool ()].bool ())
4180
81+ recall = true_positives / (true_positives + false_negatives )
4282 return recall
0 commit comments