22import torchmetrics
33
44
def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor:
    """
    Custom reduction function for distributed metric-state synchronization.

    Sums the stacked per-process state tensors along dim 0, collapsing the
    gathered states back into a single aggregate state tensor.

    Args:
        input (torch.Tensor): Stacked per-process state tensors; dim 0 is the
            process dimension.

    Returns:
        torch.Tensor: The element-wise sum over dim 0.
    """
    # NOTE(review): removed a leftover debug print of `input.device` that fired
    # on every distributed sync; the reduction itself is unchanged.
    return torch.sum(input, dim=0)
817
918
1019class MacroF1 (torchmetrics .Metric ):
11- def __init__ (self , num_labels , dist_sync_on_step = False , threshold = 0.5 ):
20+ """
21+ Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class.
22+ This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined
23+ values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value.
24+ Here, the mean is only taken over classes which have at least one positive sample.
25+
26+ Args:
27+ num_labels (int): Number of classes/labels.
28+ dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
29+ before returning the value at the step. Default: False.
30+ threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
31+ Default: 0.5.
32+ """
33+
34+ def __init__ (
35+ self , num_labels : int , dist_sync_on_step : bool = False , threshold : float = 0.5
36+ ):
1237 super ().__init__ (dist_sync_on_step = dist_sync_on_step )
1338
1439 self .add_state (
@@ -28,15 +53,29 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
2853 )
2954 self .threshold = threshold
3055
31- def update (self , preds : torch .Tensor , labels : torch .Tensor ):
56+ def update (self , preds : torch .Tensor , labels : torch .Tensor ) -> None :
57+ """
58+ Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels.
59+
60+ Args:
61+ preds (torch.Tensor): Predictions from the model.
62+ labels (torch.Tensor): Ground truth labels.
63+ """
3264 tps = torch .sum (
3365 torch .logical_and (preds > self .threshold , labels .to (torch .bool )), dim = 0
3466 )
3567 self .true_positives += tps
3668 self .positive_predictions += torch .sum (preds > self .threshold , dim = 0 )
3769 self .positive_labels += torch .sum (labels , dim = 0 )
3870
39- def compute (self ):
71+ def compute (self ) -> torch .Tensor :
72+ """
73+ Compute the Macro F1 score.
74+
75+ Returns:
76+ torch.Tensor: The computed Macro F1 score.
77+ """
78+
4079 # ignore classes without positive labels
4180 # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
4281 # which is propagated to the classwise_f1 and then turned into 0
@@ -50,14 +89,22 @@ def compute(self):
5089
5190
5291class BalancedAccuracy (torchmetrics .Metric ):
53- """Balanced Accuracy = (TPR + TNR) / 2 = ( TP/(TP + FN) + (TN)/(TN + FP) ) / 2
54-
55- This metric computes the balanced accuracy, which is the average of true positive rate (TPR)
56- and true negative rate (TNR). It is useful for imbalanced datasets where the classes are not
57- represented equally.
92+ """
93+ Computes the Balanced Accuracy, which is the average of true positive rate (TPR) and true negative rate (TNR).
94+ Useful for imbalanced datasets.
95+ Balanced Accuracy = (TPR + TNR)/2 = (TP/(TP + FN) + (TN)/(TN + FP))/2
96+
97+ Args:
98+ num_labels (int): Number of classes/labels.
99+ dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
100+ before returning the value at the step. Default: False.
101+ threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
102+ Default: 0.5.
58103 """
59104
60- def __init__ (self , num_labels , dist_sync_on_step = False , threshold = 0.5 ):
105+ def __init__ (
106+ self , num_labels : int , dist_sync_on_step : bool = False , threshold : float = 0.5
107+ ):
61108 super ().__init__ (dist_sync_on_step = dist_sync_on_step )
62109
63110 self .add_state (
@@ -86,8 +133,14 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
86133
87134 self .threshold = threshold
88135
89- def update (self , preds : torch .Tensor , labels : torch .Tensor ):
90- """Update the TPs, TNs ,FPs and FNs"""
136+ def update (self , preds : torch .Tensor , labels : torch .Tensor ) -> None :
137+ """
138+ Update the state (TPs, TNs, FPs, FNs) with the current batch of predictions and labels.
139+
140+ Args:
141+ preds (torch.Tensor): Predictions from the model.
142+ labels (torch.Tensor): Ground truth labels.
143+ """
91144
92145 # Size: Batch_size x Num_of_Classes;
93146 # summing over 1st dimension (dim=0), gives us the True positives per class
@@ -110,9 +163,13 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor):
110163 self .true_negatives += tns
111164 self .false_negatives += fns
112165
113- def compute (self ):
114- """Compute the average value of Balanced accuracy from each batch"""
166+ def compute (self ) -> torch .Tensor :
167+ """
168+ Compute the Balanced Accuracy.
115169
170+ Returns:
171+ torch.Tensor: The computed Balanced Accuracy.
172+ """
116173 tpr = self .true_positives / (self .true_positives + self .false_negatives )
117174 tnr = self .true_negatives / (self .true_negatives + self .false_positives )
118175 # Convert the nan values to 0
0 commit comments