44
55
66class Accuracy (nn .Module ):
7+ """
8+ Computes the accuracy of a model's predictions.
9+
10+ Args
11+ ----------
12+ num_classes : int
13+ The number of classes in the classification task.
14+ macro_averaging : bool, optional
15+ If True, computes macro-average accuracy. Otherwise, computes micro-average accuracy. Default is False.
16+
17+
18+ Methods
19+ -------
20+ forward(y_true, y_pred)
21+ Stores the true and predicted labels. Typically called for each batch during the forward pass of a model.
22+ _macro_acc()
23+ Computes the macro-average accuracy.
24+ _micro_acc()
25+ Computes the micro-average accuracy.
26+ __returnmetric__()
27+ Returns the computed accuracy based on the averaging method for all stored predictions.
28+ __reset__()
29+ Resets the stored true and predicted labels.
30+
31+ Examples
32+ --------
33+ >>> y_true = torch.tensor([0, 1, 2, 3, 3])
34+ >>> y_pred = torch.tensor([0, 1, 2, 3, 0])
35+ >>> accuracy = Accuracy(num_classes=4)
36+ >>> accuracy(y_true, y_pred)
37+ >>> accuracy.__returnmetric__()
38+ 0.8
39+ >>> accuracy.__reset__()
40+ >>> accuracy.macro_averaging = True
41+ >>> accuracy(y_true, y_pred)
42+ >>> accuracy.__returnmetric__()
43+ 0.875
44+ """
45+
746 def __init__ (self , num_classes , macro_averaging = False ):
847 super ().__init__ ()
948 self .num_classes = num_classes
@@ -13,19 +52,14 @@ def __init__(self, num_classes, macro_averaging=False):
1352
1453 def forward (self , y_true , y_pred ):
1554 """
16- Compute the accuracy of the model .
55+ Store the true and predicted labels .
1756
1857 Parameters
1958 ----------
2059 y_true : torch.Tensor
2160 True labels.
2261 y_pred : torch.Tensor
23- Predicted labels.
24-
25- Returns
26- -------
27- float
28- Accuracy score.
62+ Predicted labels. Either a 1D tensor of shape (batch_size,) or a 2D tensor of shape (batch_size, num_classes).
2963 """
3064 if y_pred .dim () > 1 :
3165 y_pred = y_pred .argmax (dim = 1 )
@@ -34,14 +68,7 @@ def forward(self, y_true, y_pred):
3468
3569 def _macro_acc (self ):
3670 """
37- Compute the macro-average accuracy.
38-
39- Parameters
40- ----------
41- y_true : torch.Tensor
42- True labels.
43- y_pred : torch.Tensor
44- Predicted labels.
71+ Compute the macro-average accuracy on the stored predictions.
4572
4673 Returns
4774 -------
@@ -63,14 +90,7 @@ def _macro_acc(self):
6390
6491 def _micro_acc (self ):
6592 """
66- Compute the micro-average accuracy.
67-
68- Parameters
69- ----------
70- y_true : torch.Tensor
71- True labels.
72- y_pred : torch.Tensor
73- Predicted labels.
93+ Compute the micro-average accuracy on the stored predictions.
7494
7595 Returns
7696 -------
@@ -80,6 +100,14 @@ def _micro_acc(self):
80100 return (self .y_true == self .y_pred ).float ().mean ().item ()
81101
82102 def __returnmetric__ (self ):
103+ """
104+ Return the computed accuracy based on the averaging method for all stored predictions.
105+
106+ Returns
107+ -------
108+ float
109+ Computed accuracy score.
110+ """
83111 if self .y_true == [] or self .y_pred == []:
84112 return np .nan
85113 if isinstance (self .y_true , list ):
@@ -92,6 +120,9 @@ def __returnmetric__(self):
92120 return self ._micro_acc () if not self .macro_averaging else self ._macro_acc ()
93121
94122 def __reset__ (self ):
123+ """
124+ Reset the stored true and predicted labels.
125+ """
95126 self .y_true = []
96127 self .y_pred = []
97128 return None
0 commit comments