Skip to content

Commit 8b22c30

Browse files
committed
adjusted accuracy
1 parent 4174cd4 commit 8b22c30

File tree

1 file changed

+57
-28
lines changed

1 file changed

+57
-28
lines changed

CollaborativeCoding/metrics/accuracy.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ def __init__(self, num_classes, macro_averaging=False):
77
super().__init__()
88
self.num_classes = num_classes
99
self.macro_averaging = macro_averaging
10+
self.y_true = []
11+
self.y_pred = []
1012

1113
def forward(self, y_true, y_pred):
1214
"""
@@ -26,12 +28,10 @@ def forward(self, y_true, y_pred):
2628
"""
2729
if y_pred.dim() > 1:
2830
y_pred = y_pred.argmax(dim=1)
29-
if self.macro_averaging:
30-
return self._macro_acc(y_true, y_pred)
31-
else:
32-
return self._micro_acc(y_true, y_pred)
31+
self.y_true.append(y_true)
32+
self.y_pred.append(y_pred)
3333

34-
def _macro_acc(self, y_true, y_pred):
34+
def _macro_acc(self):
3535
"""
3636
Compute the macro-average accuracy.
3737
@@ -47,7 +47,7 @@ def _macro_acc(self, y_true, y_pred):
4747
float
4848
Macro-average accuracy score.
4949
"""
50-
y_true, y_pred = y_true.flatten(), y_pred.flatten() # Ensure 1D shape
50+
y_true, y_pred = self.y_true.flatten(), self.y_pred.flatten() # Ensure 1D shape
5151

5252
classes = torch.unique(y_true) # Find unique class labels
5353
acc_per_class = []
@@ -60,7 +60,7 @@ def _macro_acc(self, y_true, y_pred):
6060
macro_acc = torch.stack(acc_per_class).mean().item() # Average across classes
6161
return macro_acc
6262

63-
def _micro_acc(self, y_true, y_pred):
63+
def _micro_acc(self):
6464
"""
6565
Compute the micro-average accuracy.
6666
@@ -76,27 +76,56 @@ def _micro_acc(self, y_true, y_pred):
7676
float
7777
Micro-average accuracy score.
7878
"""
79-
return (y_true == y_pred).float().mean().item()
79+
print(self.y_true, self.y_pred)
80+
return (self.y_true == self.y_pred).float().mean().item()
81+
82+
def __returnmetric__(self):
83+
print(self.y_true, self.y_pred)
84+
print(self.y_true == [], self.y_pred == [])
85+
print(len(self.y_true), len(self.y_pred))
86+
print(type(self.y_true), type(self.y_pred))
87+
if self.y_true == [] or self.y_pred == []:
88+
return 0.0
89+
if isinstance(self.y_true,list):
90+
if len(self.y_true) == 1:
91+
self.y_true = self.y_true[0]
92+
self.y_pred = self.y_pred[0]
93+
else:
94+
self.y_true = torch.cat(self.y_true)
95+
self.y_pred = torch.cat(self.y_pred)
96+
return self._micro_acc() if not self.macro_averaging else self._macro_acc()
97+
98+
def __resetmetric__(self):
99+
self.y_true = []
100+
self.y_pred = []
101+
return None
80102

81103

82104
if __name__ == "__main__":
83-
accuracy = Accuracy(5)
84-
macro_accuracy = Accuracy(5, macro_averaging=True)
85-
86-
y_true = torch.tensor([0, 3, 2, 3, 4])
87-
y_pred = torch.tensor([0, 1, 2, 3, 4])
88-
print(accuracy(y_true, y_pred))
89-
print(macro_accuracy(y_true, y_pred))
90-
91-
y_true = torch.tensor([0, 3, 2, 3, 4])
92-
y_onehot_pred = torch.tensor(
93-
[
94-
[1, 0, 0, 0, 0],
95-
[0, 1, 0, 0, 0],
96-
[0, 0, 1, 0, 0],
97-
[0, 0, 0, 1, 0],
98-
[0, 0, 0, 0, 1],
99-
]
100-
)
101-
print(accuracy(y_true, y_onehot_pred))
102-
print(macro_accuracy(y_true, y_onehot_pred))
105+
# Test the accuracy metric
106+
y_true = torch.tensor([0, 1, 2, 3, 4, 5])
107+
y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
108+
accuracy = Accuracy(num_classes=6, macro_averaging=False)
109+
accuracy(y_true, y_pred)
110+
print(accuracy.__returnmetric__()) # 1.0
111+
accuracy.__resetmetric__()
112+
print(accuracy.__returnmetric__()) # 0.0
113+
y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
114+
accuracy(y_true, y_pred)
115+
print(accuracy.__returnmetric__()) # 0.8333333134651184
116+
accuracy.__resetmetric__()
117+
print(accuracy.__returnmetric__()) # 0.0
118+
accuracy.macro_averaging = True
119+
accuracy(y_true, y_pred)
120+
y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
121+
y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
122+
accuracy(y_true_1, y_pred_1)
123+
print(accuracy.__returnmetric__()) # 0.9166666865348816
124+
#accuracy.__resetmetric__()
125+
#accuracy(y_true, y_pred)
126+
#accuracy(y_true_1, y_pred_1)
127+
accuracy.macro_averaging = False
128+
print(accuracy.__returnmetric__()) # 0.8333333134651184
129+
accuracy.__resetmetric__()
130+
print(accuracy.__returnmetric__()) # 0.0
131+
print(accuracy.__resetmetric__()) # None

0 commit comments

Comments
 (0)