Skip to content

Commit bda9024

Browse files
committed
Fixed macro averaging in recall
1 parent 533614d commit bda9024

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

CollaborativeCoding/metrics/recall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@ def __compute_macro_averaging(self, y_true, y_pred):
7979
recall = 0
8080
for i in range(self.num_classes):
8181
tp = (y_true[:, i] * y_pred[:, i]).sum()
82-
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
82+
fn = (y_true[:, i] * (1 - y_pred[:, i])).sum()
8383
recall += tp / (tp + fn)
8484
recall /= self.num_classes
8585

8686
return recall
8787

8888
def __compute_micro_averaging(self, y_true, y_pred):
8989
true_positives = (y_true * y_pred).sum()
90-
false_negatives = torch.sum(~y_pred[y_true.bool()].bool())
90+
false_negatives = (y_true * (1 - y_pred)).sum()
9191

9292
recall = true_positives / (true_positives + false_negatives)
9393
return recall

0 commit comments

Comments
 (0)