|
| 1 | +import numpy as np |
1 | 2 | import torch |
2 | 3 | import torch.nn as nn |
3 | 4 |
|
@@ -57,26 +58,49 @@ def __init__(self, num_classes, macro_averaging=False): |
57 | 58 | self.num_classes = num_classes |
58 | 59 | self.macro_averaging = macro_averaging |
59 | 60 |
|
| 61 | + self.__y_true = [] |
| 62 | + self.__y_pred = [] |
| 63 | + |
60 | 64 | def forward(self, true, logits): |
61 | 65 | pred = logits.argmax(dim=-1) |
62 | 66 | y_true = one_hot_encode(true, self.num_classes) |
63 | 67 | y_pred = one_hot_encode(pred, self.num_classes) |
64 | 68 |
|
| 69 | + self.__y_true.append(y_true) |
| 70 | + self.__y_pred.append(y_pred) |
| 71 | + |
| 72 | + def compute(self, y_true, y_pred): |
65 | 73 | if self.macro_averaging: |
66 | | - recall = 0 |
67 | | - for i in range(self.num_classes): |
68 | | - tp = (y_true[:, i] * y_pred[:, i]).sum() |
69 | | - fn = torch.sum(~y_pred[y_true[:, i].bool()].bool()) |
70 | | - recall += tp / (tp + fn) |
71 | | - recall /= self.num_classes |
72 | | - else: |
73 | | - recall = self.__compute(y_true, y_pred) |
| 74 | + return self.__compute_macro_averaging(y_true, y_pred) |
| 75 | + |
| 76 | + return self.__compute_micro_averaging(y_true, y_pred) |
| 77 | + |
| 78 | + def __compute_macro_averaging(self, y_true, y_pred): |
| 79 | + recall = 0 |
| 80 | + for i in range(self.num_classes): |
| 81 | + tp = (y_true[:, i] * y_pred[:, i]).sum() |
| 82 | + fn = torch.sum(~y_pred[y_true[:, i].bool()].bool()) |
| 83 | + recall += tp / (tp + fn) |
| 84 | + recall /= self.num_classes |
74 | 85 |
|
75 | 86 | return recall |
76 | 87 |
|
77 | | - def __compute(self, y_true, y_pred): |
| 88 | + def __compute_micro_averaging(self, y_true, y_pred): |
78 | 89 | true_positives = (y_true * y_pred).sum() |
79 | 90 | false_negatives = torch.sum(~y_pred[y_true.bool()].bool()) |
80 | 91 |
|
81 | 92 | recall = true_positives / (true_positives + false_negatives) |
82 | 93 | return recall |
| 94 | + |
| 95 | + def __returnmetric__(self): |
| 96 | + if len(self.__y_true) == 0 and len(self.__y_pred) == 0: |
| 97 | + return np.nan |
| 98 | + |
| 99 | + y_true = torch.cat(self.__y_true, dim=0) |
| 100 | + y_pred = torch.cat(self.__y_pred, dim=0) |
| 101 | + |
| 102 | + return self.compute(y_true, y_pred) |
| 103 | + |
| 104 | + def __reset__(self): |
| 105 | + self.__y_true = [] |
| 106 | + self.__y_pred = [] |
0 commit comments