1+ import numpy as np
12import torch
23from torch import nn
34
@@ -7,6 +8,8 @@ def __init__(self, num_classes, macro_averaging=False):
78 super ().__init__ ()
89 self .num_classes = num_classes
910 self .macro_averaging = macro_averaging
11+ self .y_true = []
12+ self .y_pred = []
1013
1114 def forward (self , y_true , y_pred ):
1215 """
@@ -26,12 +29,10 @@ def forward(self, y_true, y_pred):
2629 """
2730 if y_pred .dim () > 1 :
2831 y_pred = y_pred .argmax (dim = 1 )
29- if self .macro_averaging :
30- return self ._macro_acc (y_true , y_pred )
31- else :
32- return self ._micro_acc (y_true , y_pred )
32+ self .y_true .append (y_true )
33+ self .y_pred .append (y_pred )
3334
34- def _macro_acc (self , y_true , y_pred ):
35+ def _macro_acc (self ):
3536 """
3637 Compute the macro-average accuracy.
3738
@@ -47,7 +48,7 @@ def _macro_acc(self, y_true, y_pred):
4748 float
4849 Macro-average accuracy score.
4950 """
50- y_true , y_pred = y_true .flatten (), y_pred .flatten () # Ensure 1D shape
51+ y_true , y_pred = self . y_true .flatten (), self . y_pred .flatten () # Ensure 1D shape
5152
5253 classes = torch .unique (y_true ) # Find unique class labels
5354 acc_per_class = []
@@ -60,7 +61,7 @@ def _macro_acc(self, y_true, y_pred):
6061 macro_acc = torch .stack (acc_per_class ).mean ().item () # Average across classes
6162 return macro_acc
6263
63- def _micro_acc (self , y_true , y_pred ):
64+ def _micro_acc (self ):
6465 """
6566 Compute the micro-average accuracy.
6667
@@ -76,27 +77,21 @@ def _micro_acc(self, y_true, y_pred):
7677 float
7778 Micro-average accuracy score.
7879 """
79- return (y_true == y_pred ).float ().mean ().item ()
80-
81-
82- if __name__ == "__main__" :
83- accuracy = Accuracy (5 )
84- macro_accuracy = Accuracy (5 , macro_averaging = True )
85-
86- y_true = torch .tensor ([0 , 3 , 2 , 3 , 4 ])
87- y_pred = torch .tensor ([0 , 1 , 2 , 3 , 4 ])
88- print (accuracy (y_true , y_pred ))
89- print (macro_accuracy (y_true , y_pred ))
90-
91- y_true = torch .tensor ([0 , 3 , 2 , 3 , 4 ])
92- y_onehot_pred = torch .tensor (
93- [
94- [1 , 0 , 0 , 0 , 0 ],
95- [0 , 1 , 0 , 0 , 0 ],
96- [0 , 0 , 1 , 0 , 0 ],
97- [0 , 0 , 0 , 1 , 0 ],
98- [0 , 0 , 0 , 0 , 1 ],
99- ]
100- )
101- print (accuracy (y_true , y_onehot_pred ))
102- print (macro_accuracy (y_true , y_onehot_pred ))
80+ return (self .y_true == self .y_pred ).float ().mean ().item ()
81+
82+ def __returnmetric__ (self ):
83+ if self .y_true == [] or self .y_pred == []:
84+ return np .nan
85+ if isinstance (self .y_true , list ):
86+ if len (self .y_true ) == 1 :
87+ self .y_true = self .y_true [0 ]
88+ self .y_pred = self .y_pred [0 ]
89+ else :
90+ self .y_true = torch .cat (self .y_true )
91+ self .y_pred = torch .cat (self .y_pred )
92+ return self ._micro_acc () if not self .macro_averaging else self ._macro_acc ()
93+
94+ def __reset__ (self ):
95+ self .y_true = []
96+ self .y_pred = []
97+ return None
0 commit comments