def __init__(self, num_classes, macro_averaging=False):
    """Accuracy metric that accumulates predictions across batches.

    Parameters
    ----------
    num_classes : int
        Number of target classes.
    macro_averaging : bool, optional
        If True, report the per-class (macro) average accuracy;
        otherwise report the plain (micro) accuracy. Default False.
    """
    super().__init__()
    self.num_classes = num_classes
    self.macro_averaging = macro_averaging
    # Per-batch accumulators filled by forward(); consumed and
    # concatenated by __returnmetric__().
    self.y_true = []
    self.y_pred = []
def forward(self, y_true, y_pred):
    """Accumulate one batch of targets and predictions.

    Parameters
    ----------
    y_true : torch.Tensor
        Ground-truth labels.
    y_pred : torch.Tensor
        Either hard labels (1-D) or per-class scores/logits (2-D);
        multi-dimensional input is collapsed to labels via argmax.
    """
    labels = y_pred.argmax(dim=1) if y_pred.dim() > 1 else y_pred
    self.y_true.append(y_true)
    self.y_pred.append(labels)
3333
34- def _macro_acc (self , y_true , y_pred ):
34+ def _macro_acc (self ):
3535 """
3636 Compute the macro-average accuracy.
3737
@@ -47,7 +47,7 @@ def _macro_acc(self, y_true, y_pred):
4747 float
4848 Macro-average accuracy score.
4949 """
50- y_true , y_pred = y_true .flatten (), y_pred .flatten () # Ensure 1D shape
50+ y_true , y_pred = self . y_true .flatten (), self . y_pred .flatten () # Ensure 1D shape
5151
5252 classes = torch .unique (y_true ) # Find unique class labels
5353 acc_per_class = []
@@ -60,7 +60,7 @@ def _macro_acc(self, y_true, y_pred):
6060 macro_acc = torch .stack (acc_per_class ).mean ().item () # Average across classes
6161 return macro_acc
6262
63- def _micro_acc (self , y_true , y_pred ):
63+ def _micro_acc (self ):
6464 """
6565 Compute the micro-average accuracy.
6666
@@ -76,27 +76,56 @@ def _micro_acc(self, y_true, y_pred):
7676 float
7777 Micro-average accuracy score.
7878 """
79- return (y_true == y_pred ).float ().mean ().item ()
79+ print (self .y_true , self .y_pred )
80+ return (self .y_true == self .y_pred ).float ().mean ().item ()
81+
82+ def __returnmetric__ (self ):
83+ print (self .y_true , self .y_pred )
84+ print (self .y_true == [], self .y_pred == [])
85+ print (len (self .y_true ), len (self .y_pred ))
86+ print (type (self .y_true ), type (self .y_pred ))
87+ if self .y_true == [] or self .y_pred == []:
88+ return 0.0
89+ if isinstance (self .y_true ,list ):
90+ if len (self .y_true ) == 1 :
91+ self .y_true = self .y_true [0 ]
92+ self .y_pred = self .y_pred [0 ]
93+ else :
94+ self .y_true = torch .cat (self .y_true )
95+ self .y_pred = torch .cat (self .y_pred )
96+ return self ._micro_acc () if not self .macro_averaging else self ._macro_acc ()
97+
98+ def __resetmetric__ (self ):
99+ self .y_true = []
100+ self .y_pred = []
101+ return None
80102
81103
if __name__ == "__main__":
    # Smoke-test the accuracy metric.
    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])

    accuracy = Accuracy(num_classes=6, macro_averaging=False)
    accuracy(y_true, y_pred)
    print(accuracy.__returnmetric__())  # 1.0
    accuracy.__resetmetric__()
    print(accuracy.__returnmetric__())  # 0.0

    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
    accuracy(y_true, y_pred)
    print(accuracy.__returnmetric__())  # 0.8333333134651184
    accuracy.__resetmetric__()
    print(accuracy.__returnmetric__())  # 0.0

    # Macro averaging over two accumulated batches.
    accuracy.macro_averaging = True
    accuracy(y_true, y_pred)
    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
    accuracy(y_true_1, y_pred_1)
    print(accuracy.__returnmetric__())  # 0.9166666865348816

    # Switching the flag reuses the already-concatenated tensors.
    accuracy.macro_averaging = False
    print(accuracy.__returnmetric__())  # 0.8333333134651184
    accuracy.__resetmetric__()
    print(accuracy.__returnmetric__())  # 0.0
    print(accuracy.__resetmetric__())  # None
0 commit comments