@@ -18,6 +18,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
1818
1919 self .num_classes = num_classes
2020 self .macro_averaging = macro_averaging
21+ self .y_true = []
22+ self .y_pred = []
2123
2224 def forward (self , y_true : torch .tensor , logits : torch .tensor ) -> torch .tensor :
2325 """Compute precision of model
@@ -35,17 +37,11 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
3537 Precision score
3638 """
3739 y_pred = logits .argmax (dim = - 1 )
38- return (
39- self ._macro_avg_precision (y_true , y_pred )
40- if self .macro_averaging
41- else self ._micro_avg_precision (y_true , y_pred )
42- )
40+
41+ # Append to the class-global values
42+ self .y_true .append (y_true )
43+ self .y_pred .append (y_pred )
4344
44- def accumulate (self ):
45- pass # TODO fill
46-
47- def reset (self ):
48- pass # TODO fill
4945
5046 def _micro_avg_precision (
5147 self , y_true : torch .tensor , y_pred : torch .tensor
@@ -64,7 +60,6 @@ def _micro_avg_precision(
6460 torch.tensor
6561 Micro-averaged precision
6662 """
67- print (y_true .shape )
6863 true_oh = torch .zeros (y_true .size (0 ), self .num_classes ).scatter_ (
6964 1 , y_true .unsqueeze (1 ), 1
7065 )
@@ -103,6 +98,27 @@ def _macro_avg_precision(
10398 fp = torch .sum (~ true_oh .bool () * pred_oh , 0 )
10499
105100 return torch .nanmean (tp / (tp + fp ))
101+
102+ def __returnmetric__ (self ):
103+ if self .y_true == [] and self .y_pred == []:
104+ return []
105+ elif self .y_true == [] or self .y_pred == []:
106+ raise ValueError ("y_true or y_pred is empty." )
107+ self .y_true = torch .cat (self .y_true )
108+ self .y_pred = torch .cat (self .y_pred )
109+
110+ return self ._macro_avg_precision (self .y_true , self .y_pred ) if self .macro_averaging else self ._micro_avg_precision (self .y_true , self .y_pred )
111+
112+ def __reset__ (self ):
113+ """Resets the class-global lists of true and predicted values to empty lists.
114+
115+ Returns
116+ -------
117+ None
118+ Returns None
119+ """
120+ self .y_true = self .y_pred = []
121+ return None
106122
107123
108124if __name__ == "__main__" :
0 commit comments