66class F1Score (nn .Module ):
77 """
88 F1 Score implementation with support for both macro and micro averaging.
9-
109 This class computes the F1 score during training using either macro or micro averaging.
11- The F1 score is calculated based on the true positives (TP), false positives (FP),
12- and false negatives (FN) for each class.
13-
1410 Parameters
1511 ----------
1612 num_classes : int
1713 The number of classes in the classification task.
1814
19- macro_averaging : bool, optional, default=False
15+ macro_averaging : bool, default=False
2016 If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
21-
22- Attributes
23- ----------
24- num_classes : int
25- The number of classes in the classification task.
26-
27- tp : torch.Tensor
28- Tensor storing the count of True Positives (TP) for each class.
29-
30- fp : torch.Tensor
31- Tensor storing the count of False Positives (FP) for each class.
32-
33- fn : torch.Tensor
34- Tensor storing the count of False Negatives (FN) for each class.
35-
36- macro_averaging : bool
37- A flag indicating whether to compute the macro-averaged F1 score or not.
3817 """
3918
4019 def __init__ (self , num_classes , macro_averaging = False ):
41- """
42- Initializes the F1Score object, setting up the necessary state variables.
43-
44- Parameters
45- ----------
46- num_classes : int
47- The number of classes in the classification task.
48-
49- macro_averaging : bool, optional, default=False
50- If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
51- """
5220 super ().__init__ ()
53-
5421 self .num_classes = num_classes
5522 self .macro_averaging = macro_averaging
5623 self .y_true = []
5724 self .y_pred = []
58- # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
59- self .tp = torch .zeros (num_classes )
60- self .fp = torch .zeros (num_classes )
61- self .fn = torch .zeros (num_classes )
6225
63- def _micro_F1 (self , target , preds ):
26+
27+ def forward (self , target , preds ):
6428 """
65- Compute the Micro F1 score by aggregating TP, FP, and FN across all classes .
29+ Stores predictions and targets for computing the F1 score .
6630
67- Micro F1 score is calculated globally by considering all predictions together, regardless of class.
31+ Parameters
32+ ----------
33+ preds : torch.Tensor
34+ Predicted logits (shape: [batch_size, num_classes]).
35+ target : torch.Tensor
36+ True labels (shape: [batch_size]).
37+ """
38+ preds = torch .argmax (preds , dim = - 1 ) # Convert logits to class indices
39+ self .y_true .append (target .detach ())
40+ if preds .dim () == 0 : # Scalar (e.g., single class prediction)
41+ preds = preds .unsqueeze (0 ) # Add batch dimension
42+ self .y_pred .append (preds .detach ())
43+
44+ def compute_f1 (self ):
45+ """
46+ Computes the F1 score (Micro or Macro).
6847
6948 Returns
7049 -------
7150 torch.Tensor
72- The micro-averaged F1 score.
51+ The computed F1 score.
7352 """
74- for i in range (self .num_classes ):
75- self .tp [i ] += torch .sum ((preds == i ) & (target == i )).float ()
76- self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
77- self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
53+ if not self .y_true or not self .y_pred : # Check if empty
54+ return torch .tensor (np .nan )
7855
79- tp = torch .sum (self .tp )
80- fp = torch .sum (self .fp )
81- fn = torch .sum (self .fn )
56+ # Convert lists to tensors
57+ y_true = torch .cat (self .y_true )
58+ y_pred = torch .cat (self .y_pred )
59+
60+ return self ._macro_F1 (y_true , y_pred ) if self .macro_averaging else self ._micro_F1 (y_true , y_pred )
61+
62+ def _micro_F1 (self , target , preds ):
63+ """Computes Micro F1 Score (global TP, FP, FN)."""
64+ tp = torch .sum (preds == target ).float ()
65+ fp = torch .sum (preds != target ).float ()
66+ fn = fp # Since all errors are either FP or FN
8267
83- precision = tp / (tp + fp + 1e-8 ) # Avoid division by zero
84- recall = tp / (tp + fn + 1e-8 ) # Avoid division by zero
68+ precision = tp / (tp + fp + 1e-8 )
69+ recall = tp / (tp + fn + 1e-8 )
70+ f1 = 2 * (precision * recall ) / (precision + recall + 1e-8 )
8571
86- f1 = (
87- 2 * precision * recall / (precision + recall + 1e-8 )
88- ) # Avoid division by zero
8972 return f1
9073
9174 def _macro_F1 (self , target , preds ):
92- """
93- Compute the Macro F1 score by calculating the F1 score per class and averaging.
94-
95- Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally,
96- regardless of their frequency.
75+ """Computes Macro F1 Score in a vectorized way (no loops)."""
76+ num_classes = self .num_classes
77+ target = target .long () # Ensure target is a LongTensor
78+ preds = preds .long ()
79+ # Create one-hot encodings of the true and predicted labels
80+ target_one_hot = torch .nn .functional .one_hot (target , num_classes = num_classes )
81+ preds_one_hot = torch .nn .functional .one_hot (preds , num_classes = num_classes )
9782
98- Returns
99- -------
100- torch.Tensor
101- The macro-averaged F1 score.
102- """
103- # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
104- for i in range (self .num_classes ):
105- self .tp [i ] += torch .sum ((preds == i ) & (target == i )).float ()
106- self .fp [i ] += torch .sum ((preds == i ) & (target != i )).float ()
107- self .fn [i ] += torch .sum ((preds != i ) & (target == i )).float ()
108-
109- precision_per_class = self .tp / (
110- self .tp + self .fp + 1e-8
111- ) # Avoid division by zero
112- recall_per_class = self .tp / (
113- self .tp + self .fn + 1e-8
114- ) # Avoid division by zero
115- f1_per_class = (
116- 2
117- * precision_per_class
118- * recall_per_class
119- / (precision_per_class + recall_per_class + 1e-8 )
120- ) # Avoid division by zero
121-
122- # Take the average of F1 scores across all classes
123- f1_score = torch .mean (f1_per_class )
124- return f1_score
125-
126- def forward (self , preds , target ):
127- """
83+ # Compute TP, FP, FN for each class
84+ tp = torch .sum (target_one_hot * preds_one_hot , dim = 0 ).float ()
85+ fp = torch .sum (preds_one_hot * (1 - target_one_hot ), dim = 0 ).float ()
86+ fn = torch .sum (target_one_hot * (1 - preds_one_hot ), dim = 0 ).float ()
12887
129- Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
88+ # Compute precision and recall per class
89+ precision = tp / (tp + fp + 1e-8 )
90+ recall = tp / (tp + fn + 1e-8 )
13091
131- This method computes the F1 score based on the predictions and true labels. It can compute either the
132- macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag.
92+ # Compute per-class F1 score
93+ f1_per_class = 2 * ( precision * recall ) / ( precision + recall + 1e-8 )
13394
134- Parameters
135- ----------
136- preds : torch.Tensor
137- Predicted logits or class indices (shape: [batch_size, num_classes]).
138- These logits are typically the output of a softmax or sigmoid activation.
95+ # Compute Macro F1 (mean over all classes)
96+ return torch .mean (f1_per_class )
13997
140- target : torch.Tensor
141- True labels (shape: [batch_size]), where each element is an integer representing the true class.
98+ def __returnmetric__ (self ):
99+ """
100+ Computes and returns the F1 score (Micro or Macro).
142101
143102 Returns
144103 -------
145104 torch.Tensor
146- The computed F1 score (either micro or macro, based on `macro_averaging`) .
105+ The computed F1 score.
147106 """
148- preds = torch .argmax (preds , dim = - 1 )
149- self .y_true .append (target )
150- self .y_pred .append (preds )
107+ if not self .y_true or not self .y_pred : # Check if empty
108+ return torch .tensor (np .nan )
151109
152- def __returnmetric__ (self ):
153- if self .y_true == [] or self .y_pred == []:
154- return np .nan
155- if isinstance (self .y_true , list ):
156- if len (self .y_true ) == 1 :
157- self .y_true = self .y_true [0 ]
158- self .y_pred = self .y_pred [0 ]
159- else :
160- self .y_true = torch .cat (self .y_true )
161- self .y_pred = torch .cat (self .y_pred )
162- return (
163- self ._micro_F1 (self .y_true , self .y_pred )
164- if not self .macro_averaging
165- else self ._macro_F1 (self .y_true , self .y_pred )
166- )
110+ # Convert lists to tensors
111+ y_true = torch .cat ([t .unsqueeze (0 ) if t .dim () == 0 else t for t in self .y_true ])
112+ y_pred = torch .cat ([t .unsqueeze (0 ) if t .dim () == 0 else t for t in self .y_pred ])
113+
114+ return self ._macro_F1 (y_true , y_pred ) if self .macro_averaging else self ._micro_F1 (y_true , y_pred )
167115
168116 def __reset__ (self ):
117+ """Resets stored predictions and targets."""
169118 self .y_true = []
170- self .y_pred = []
171- return None
119+ self .y_pred = []
0 commit comments