|
6 | 6 | class F1Score(nn.Module): |
7 | 7 | """ |
8 | 8 | F1 Score implementation with support for both macro and micro averaging. |
9 | | -
|
10 | 9 | This class computes the F1 score during training using either macro or micro averaging. |
11 | | - The F1 score is calculated based on the true positives (TP), false positives (FP), |
12 | | - and false negatives (FN) for each class. |
13 | | -
|
14 | 10 | Parameters |
15 | 11 | ---------- |
16 | 12 | num_classes : int |
17 | 13 | The number of classes in the classification task. |
18 | 14 |
|
19 | | - macro_averaging : bool, optional, default=False |
| 15 | + macro_averaging : bool, default=False |
20 | 16 | If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. |
21 | | -
|
22 | | - Attributes |
23 | | - ---------- |
24 | | - num_classes : int |
25 | | - The number of classes in the classification task. |
26 | | -
|
27 | | - tp : torch.Tensor |
28 | | - Tensor storing the count of True Positives (TP) for each class. |
29 | | -
|
30 | | - fp : torch.Tensor |
31 | | - Tensor storing the count of False Positives (FP) for each class. |
32 | | -
|
33 | | - fn : torch.Tensor |
34 | | - Tensor storing the count of False Negatives (FN) for each class. |
35 | | -
|
36 | | - macro_averaging : bool |
37 | | - A flag indicating whether to compute the macro-averaged F1 score or not. |
38 | 17 | """ |
39 | 18 |
|
40 | 19 | def __init__(self, num_classes, macro_averaging=False): |
41 | | - """ |
42 | | - Initializes the F1Score object, setting up the necessary state variables. |
43 | | -
|
44 | | - Parameters |
45 | | - ---------- |
46 | | - num_classes : int |
47 | | - The number of classes in the classification task. |
48 | | -
|
49 | | - macro_averaging : bool, optional, default=False |
50 | | - If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. |
51 | | - """ |
52 | 20 | super().__init__() |
53 | | - |
54 | 21 | self.num_classes = num_classes |
55 | 22 | self.macro_averaging = macro_averaging |
56 | 23 | self.y_true = [] |
57 | 24 | self.y_pred = [] |
58 | | - # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN) |
59 | | - self.tp = torch.zeros(num_classes) |
60 | | - self.fp = torch.zeros(num_classes) |
61 | | - self.fn = torch.zeros(num_classes) |
62 | 25 |
|
63 | | - def _micro_F1(self, target, preds): |
| 26 | + def forward(self, target, preds): |
64 | 27 | """ |
65 | | - Compute the Micro F1 score by aggregating TP, FP, and FN across all classes. |
| 28 | + Stores predictions and targets for computing the F1 score. |
66 | 29 |
|
67 | | - Micro F1 score is calculated globally by considering all predictions together, regardless of class. |
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + preds : torch.Tensor |
| 33 | + Predicted logits (shape: [batch_size, num_classes]). |
| 34 | + target : torch.Tensor |
| 35 | + True labels (shape: [batch_size]). |
| 36 | + """ |
| 37 | + preds = torch.argmax(preds, dim=-1) # Convert logits to class indices |
| 38 | + self.y_true.append(target.detach()) |
| 39 | + if preds.dim() == 0: # Scalar (e.g., single class prediction) |
| 40 | + preds = preds.unsqueeze(0) # Add batch dimension |
| 41 | + self.y_pred.append(preds.detach()) |
| 42 | + |
| 43 | + def compute_f1(self): |
| 44 | + """ |
| 45 | + Computes the F1 score (Micro or Macro). |
68 | 46 |
|
69 | 47 | Returns |
70 | 48 | ------- |
71 | 49 | torch.Tensor |
72 | | - The micro-averaged F1 score. |
| 50 | + The computed F1 score. |
73 | 51 | """ |
74 | | - for i in range(self.num_classes): |
75 | | - self.tp[i] += torch.sum((preds == i) & (target == i)).float() |
76 | | - self.fp[i] += torch.sum((preds == i) & (target != i)).float() |
77 | | - self.fn[i] += torch.sum((preds != i) & (target == i)).float() |
| 52 | + if not self.y_true or not self.y_pred: # Check if empty |
| 53 | + return torch.tensor(np.nan) |
| 54 | + |
| 55 | + # Convert lists to tensors |
| 56 | + y_true = torch.cat(self.y_true) |
| 57 | + y_pred = torch.cat(self.y_pred) |
78 | 58 |
|
79 | | - tp = torch.sum(self.tp) |
80 | | - fp = torch.sum(self.fp) |
81 | | - fn = torch.sum(self.fn) |
| 59 | + return ( |
| 60 | + self._macro_F1(y_true, y_pred) |
| 61 | + if self.macro_averaging |
| 62 | + else self._micro_F1(y_true, y_pred) |
| 63 | + ) |
82 | 64 |
|
83 | | - precision = tp / (tp + fp + 1e-8) # Avoid division by zero |
84 | | - recall = tp / (tp + fn + 1e-8) # Avoid division by zero |
| 65 | + def _micro_F1(self, target, preds): |
| 66 | + """Computes Micro F1 Score (global TP, FP, FN).""" |
| 67 | + tp = torch.sum(preds == target).float() |
| 68 | + fp = torch.sum(preds != target).float() |
| 69 | + fn = fp # Since all errors are either FP or FN |
| 70 | + |
| 71 | + precision = tp / (tp + fp + 1e-8) |
| 72 | + recall = tp / (tp + fn + 1e-8) |
| 73 | + f1 = 2 * (precision * recall) / (precision + recall + 1e-8) |
85 | 74 |
|
86 | | - f1 = ( |
87 | | - 2 * precision * recall / (precision + recall + 1e-8) |
88 | | - ) # Avoid division by zero |
89 | 75 | return f1 |
90 | 76 |
|
91 | 77 | def _macro_F1(self, target, preds): |
92 | | - """ |
93 | | - Compute the Macro F1 score by calculating the F1 score per class and averaging. |
94 | | -
|
95 | | - Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally, |
96 | | - regardless of their frequency. |
| 78 | + """Computes Macro F1 Score in a vectorized way (no loops).""" |
| 79 | + num_classes = self.num_classes |
| 80 | + target = target.long() # Ensure target is a LongTensor |
| 81 | + preds = preds.long() |
| 82 | + # Create one-hot encodings of the true and predicted labels |
| 83 | + target_one_hot = torch.nn.functional.one_hot(target, num_classes=num_classes) |
| 84 | + preds_one_hot = torch.nn.functional.one_hot(preds, num_classes=num_classes) |
97 | 85 |
|
98 | | - Returns |
99 | | - ------- |
100 | | - torch.Tensor |
101 | | - The macro-averaged F1 score. |
102 | | - """ |
103 | | - # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class |
104 | | - for i in range(self.num_classes): |
105 | | - self.tp[i] += torch.sum((preds == i) & (target == i)).float() |
106 | | - self.fp[i] += torch.sum((preds == i) & (target != i)).float() |
107 | | - self.fn[i] += torch.sum((preds != i) & (target == i)).float() |
108 | | - |
109 | | - precision_per_class = self.tp / ( |
110 | | - self.tp + self.fp + 1e-8 |
111 | | - ) # Avoid division by zero |
112 | | - recall_per_class = self.tp / ( |
113 | | - self.tp + self.fn + 1e-8 |
114 | | - ) # Avoid division by zero |
115 | | - f1_per_class = ( |
116 | | - 2 |
117 | | - * precision_per_class |
118 | | - * recall_per_class |
119 | | - / (precision_per_class + recall_per_class + 1e-8) |
120 | | - ) # Avoid division by zero |
121 | | - |
122 | | - # Take the average of F1 scores across all classes |
123 | | - f1_score = torch.mean(f1_per_class) |
124 | | - return f1_score |
125 | | - |
126 | | - def forward(self, preds, target): |
127 | | - """ |
| 86 | + # Compute TP, FP, FN for each class |
| 87 | + tp = torch.sum(target_one_hot * preds_one_hot, dim=0).float() |
| 88 | + fp = torch.sum(preds_one_hot * (1 - target_one_hot), dim=0).float() |
| 89 | + fn = torch.sum(target_one_hot * (1 - preds_one_hot), dim=0).float() |
128 | 90 |
|
129 | | - Update the True Positives, False Positives, and False Negatives, and compute the F1 score. |
| 91 | + # Compute precision and recall per class |
| 92 | + precision = tp / (tp + fp + 1e-8) |
| 93 | + recall = tp / (tp + fn + 1e-8) |
130 | 94 |
|
131 | | - This method computes the F1 score based on the predictions and true labels. It can compute either the |
132 | | - macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag. |
| 95 | + # Compute per-class F1 score |
| 96 | + f1_per_class = 2 * (precision * recall) / (precision + recall + 1e-8) |
133 | 97 |
|
134 | | - Parameters |
135 | | - ---------- |
136 | | - preds : torch.Tensor |
137 | | - Predicted logits or class indices (shape: [batch_size, num_classes]). |
138 | | - These logits are typically the output of a softmax or sigmoid activation. |
| 98 | + # Compute Macro F1 (mean over all classes) |
| 99 | + return torch.mean(f1_per_class) |
139 | 100 |
|
140 | | - target : torch.Tensor |
141 | | - True labels (shape: [batch_size]), where each element is an integer representing the true class. |
| 101 | + def __returnmetric__(self): |
| 102 | + """ |
| 103 | + Computes and returns the F1 score (Micro or Macro). |
142 | 104 |
|
143 | 105 | Returns |
144 | 106 | ------- |
145 | 107 | torch.Tensor |
146 | | - The computed F1 score (either micro or macro, based on `macro_averaging`). |
| 108 | + The computed F1 score. |
147 | 109 | """ |
148 | | - preds = torch.argmax(preds, dim=-1) |
149 | | - self.y_true.append(target) |
150 | | - self.y_pred.append(preds) |
| 110 | + if not self.y_true or not self.y_pred: # Check if empty |
| 111 | + return torch.tensor(np.nan) |
| 112 | + |
| 113 | + # Convert lists to tensors |
| 114 | + y_true = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_true]) |
| 115 | + y_pred = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_pred]) |
151 | 116 |
|
152 | | - def __returnmetric__(self): |
153 | | - if self.y_true == [] or self.y_pred == []: |
154 | | - return np.nan |
155 | | - if isinstance(self.y_true, list): |
156 | | - if len(self.y_true) == 1: |
157 | | - self.y_true = self.y_true[0] |
158 | | - self.y_pred = self.y_pred[0] |
159 | | - else: |
160 | | - self.y_true = torch.cat(self.y_true) |
161 | | - self.y_pred = torch.cat(self.y_pred) |
162 | 117 | return ( |
163 | | - self._micro_F1(self.y_true, self.y_pred) |
164 | | - if not self.macro_averaging |
165 | | - else self._macro_F1(self.y_true, self.y_pred) |
| 118 | + self._macro_F1(y_true, y_pred) |
| 119 | + if self.macro_averaging |
| 120 | + else self._micro_F1(y_true, y_pred) |
166 | 121 | ) |
167 | 122 |
|
168 | 123 | def __reset__(self): |
| 124 | + """Resets stored predictions and targets.""" |
169 | 125 | self.y_true = [] |
170 | 126 | self.y_pred = [] |
171 | | - return None |
0 commit comments