@@ -13,49 +13,95 @@ class Precision(nn.Module):
1313 ----------
1414 num_classes : int
1515 Number of classes in the dataset.
16- use_mean : bool
17- Whether to calculate precision as a mean of precisions or as a brute function of true positives and false positives.
16+ micro_averaging : bool
17+ Wheter to compute the micro or macro precision (default False)
1818 """
1919
20- def __init__ (self , num_classes : int , use_mean : bool = True ):
20+ def __init__ (self , num_classes : int , micro_averaging : bool = False ):
2121 super ().__init__ ()
2222
2323 self .num_classes = num_classes
24- self .use_mean = use_mean
24+ self ._micro_averaging = micro_averaging
2525
2626 def forward (self , y_true : torch .tensor , y_pred : torch .tensor ) -> torch .tensor :
27- """Calculates the precision score given number of classes and the true and predicted labels.
27+ """Compute precision of model
2828
2929 Parameters
3030 ----------
3131 y_true : torch.tensor
32- true labels
32+ True labels
3333 y_pred : torch.tensor
34- predicted labels
34+ Predicted labels
3535
3636 Returns
3737 -------
3838 torch.tensor
39- precision score
39+ Precision score
40+ """
41+ return (
42+ self ._micro_avg_precision (y_true , y_pred )
43+ if self .micro_averaging
44+ else self ._macro_avg_precision (y_true , y_pred )
45+ )
46+
47+ def _micro_avg_precision (
48+ self , y_true : torch .tensor , y_pred : torch .tensor
49+ ) -> torch .tensor :
50+ """Compute micro-average precision by first calculating true/false positive across all classes and then find the precision.
51+
52+ Parameters
53+ ----------
54+ y_true : torch.tensor
55+ True labels
56+ y_pred : torch.tensor
57+ Predicted labels
58+
59+ Returns
60+ -------
61+ torch.tensor
62+ Micro-averaged precision
4063 """
41- # One-hot encode the target tensor
4264 true_oh = torch .zeros (y_true .size (0 ), self .num_classes ).scatter_ (
4365 1 , y_true .unsqueeze (1 ), 1
4466 )
4567 pred_oh = torch .zeros (y_pred .size (0 ), self .num_classes ).scatter_ (
4668 1 , y_pred .unsqueeze (1 ), 1
4769 )
70+ tp = torch .sum (true_oh * pred_oh )
71+ fp = torch .sum (~ true_oh [pred_oh .bool ()].bool ())
4872
49- if self .use_mean :
50- tp = torch .sum (true_oh * pred_oh , 0 )
51- fp = torch .sum (~ true_oh .bool () * pred_oh , 0 )
73+ return torch .nanmean (tp / (tp + fp ))
74+
75+ def _macro_avg_precision (
76+ self , y_true : torch .tensor , y_pred : torch .tensor
77+ ) -> torch .tensor :
78+ """Compute macro-average precision by finding true/false positives of each class separately then averaging across all classes.
5279
53- else :
54- tp = torch .sum (true_oh * pred_oh )
55- fp = torch .sum (~ true_oh [pred_oh .bool ()].bool ())
80+ Parameters
81+ ----------
82+ y_true : torch.tensor
83+ True labels
84+ y_pred : torch.tensor
85+ Predicted labels
86+
87+ Returns
88+ -------
89+ torch.tensor
90+ Macro-averaged precision
91+ """
92+ true_oh = torch .zeros (y_true .size (0 ), self .num_classes ).scatter_ (
93+ 1 , y_true .unsqueeze (1 ), 1
94+ )
95+ pred_oh = torch .zeros (y_pred .size (0 ), self .num_classes ).scatter_ (
96+ 1 , y_pred .unsqueeze (1 ), 1
97+ )
98+ tp = torch .sum (true_oh * pred_oh , 0 )
99+ fp = torch .sum (~ true_oh .bool () * pred_oh , 0 )
56100
57101 return torch .nanmean (tp / (tp + fp ))
58102
59103
60104if __name__ == "__main__" :
61- pass
105+ print (
106+ "Congratulations, you succesfully ran the Precision metric class. You should be proud of this marvelous achievement!"
107+ )
0 commit comments