1+ import numpy as np
12import torch
23import torch .nn as nn
34
@@ -18,6 +19,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
1819
1920 self .num_classes = num_classes
2021 self .macro_averaging = macro_averaging
22+ self .y_true = []
23+ self .y_pred = []
2124
2225 def forward (self , y_true : torch .tensor , logits : torch .tensor ) -> torch .tensor :
2326 """Compute precision of model
@@ -35,11 +38,10 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
3538 Precision score
3639 """
3740 y_pred = logits .argmax (dim = - 1 )
38- return (
39- self ._macro_avg_precision (y_true , y_pred )
40- if self .macro_averaging
41- else self ._micro_avg_precision (y_true , y_pred )
42- )
41+
42+ # Append to the class-global values
43+ self .y_true .append (y_true )
44+ self .y_pred .append (y_pred )
4345
4446 def _micro_avg_precision (
4547 self , y_true : torch .tensor , y_pred : torch .tensor
@@ -58,7 +60,6 @@ def _micro_avg_precision(
5860 torch.tensor
5961 Micro-averaged precision
6062 """
61- print (y_true .shape )
6263 true_oh = torch .zeros (y_true .size (0 ), self .num_classes ).scatter_ (
6364 1 , y_true .unsqueeze (1 ), 1
6465 )
@@ -98,6 +99,31 @@ def _macro_avg_precision(
9899
99100 return torch .nanmean (tp / (tp + fp ))
100101
102+ def __returnmetric__ (self ):
103+ if self .y_true == [] and self .y_pred == []:
104+ return np .nan
105+ elif self .y_true == [] or self .y_pred == []:
106+ raise ValueError ("y_true or y_pred is empty." )
107+ self .y_true = torch .cat (self .y_true )
108+ self .y_pred = torch .cat (self .y_pred )
109+
110+ return (
111+ self ._macro_avg_precision (self .y_true , self .y_pred )
112+ if self .macro_averaging
113+ else self ._micro_avg_precision (self .y_true , self .y_pred )
114+ )
115+
116+ def __reset__ (self ):
117+ """Resets the class-global lists of true and predicted values to empty lists.
118+
119+ Returns
120+ -------
121+ None
122+ Returns None
123+ """
124+ self .y_true = self .y_pred = []
125+ return None
126+
101127
102128if __name__ == "__main__" :
103129 print (
0 commit comments