
Commit 24f920d

adjusted accuracy test and test_metric_wrapper to work with new method names
1 parent: 787a342 · commit: 24f920d

File tree: 2 files changed, +28 −39 lines

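The commit renames the metric-collection methods used by the tests — MetricWrapper's accumulate()/reset() become __getmetrics__()/__resetmetrics__() — and makes an empty Accuracy report np.nan instead of 0.0. The following is a minimal sketch of exercising the renamed wrapper API; the MetricWrapper constructor arguments and the "accuracy" metric name are assumptions inferred from the test parametrization, not confirmed by this diff.

    import torch

    from CollaborativeCoding.load_metric import MetricWrapper

    # Hypothetical constructor arguments, inferred from the test parametrization
    # (metric name, num_classes, macro_averaging); the real signature may differ.
    metrics = MetricWrapper("accuracy", num_classes=6, macro_averaging=False)

    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
    logits = torch.rand(6, 6)  # one row of class scores per sample

    metrics(y_true, logits)              # accumulate a batch
    score = metrics.__getmetrics__()     # was: metrics.accumulate()
    metrics.__resetmetrics__()           # was: metrics.reset()
    empty_score = metrics.__getmetrics__()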

CollaborativeCoding/metrics/accuracy.py

Lines changed: 2 additions & 26 deletions
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import numpy as np


 class Accuracy(nn.Module):

@@ -80,7 +81,7 @@ def _micro_acc(self):

     def __returnmetric__(self):
         if self.y_true == [] or self.y_pred == []:
-            return 0.0
+            return np.nan
         if isinstance(self.y_true, list):
             if len(self.y_true) == 1:
                 self.y_true = self.y_true[0]
@@ -96,28 +97,3 @@ def __reset__(self):
         return None


-if __name__ == "__main__":
-    # Test the accuracy metric
-    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
-    accuracy = Accuracy(num_classes=6, macro_averaging=False)
-    accuracy(y_true, y_pred)
-    print(accuracy.__returnmetric__())  # 1.0
-    accuracy.__resetmetric__()
-    print(accuracy.__returnmetric__())  # 0.0
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true, y_pred)
-    print(accuracy.__returnmetric__())  # 0.8333333134651184
-    accuracy.__resetmetric__()
-    print(accuracy.__returnmetric__())  # 0.0
-    accuracy.macro_averaging = True
-    accuracy(y_true, y_pred)
-    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true_1, y_pred_1)
-    print(accuracy.__returnmetric__())  # 0.9166666865348816
-    accuracy.macro_averaging = False
-    print(accuracy.__returnmetric__())  # 0.8333333134651184
-    accuracy.__resetmetric__()
-    print(accuracy.__returnmetric__())  # 0.0
-    print(accuracy.__resetmetric__())  # None
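With this change, an Accuracy instance that has not accumulated any targets or predictions returns np.nan from __returnmetric__() rather than 0.0. A small sketch, assuming callers previously treated 0.0 as the "no data yet" state and now need a NaN-aware check:

    import numpy as np

    from CollaborativeCoding.metrics import Accuracy

    acc = Accuracy(num_classes=6, macro_averaging=False)
    value = acc.__returnmetric__()  # nothing accumulated yet -> np.nan after this commit

    # Use np.isnan: `value == np.nan` is always False, so an equality check
    # cannot detect the empty state.
    if np.isnan(value):
        print("no samples accumulated yet")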

tests/test_metrics.py

Lines changed: 26 additions & 13 deletions
@@ -3,7 +3,7 @@
 import pytest

 from CollaborativeCoding.load_metric import MetricWrapper
-from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall
+from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction


 @pytest.mark.parametrize(
@@ -34,9 +34,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
     )

     metrics(y_true, logits)
-    score = metrics.accumulate()
-    metrics.reset()
-    empty_score = metrics.accumulate()
+    score = metrics.__getmetrics__()
+    metrics.__resetmetrics__()
+    empty_score = metrics.__getmetrics__()

     assert isinstance(score, dict), "Expected a dictionary output."
     assert metric in score, f"Expected {metric} metric in the output."
@@ -129,17 +129,30 @@ def test_precision():

 def test_accuracy():
     import torch
+    import numpy as np

-    accuracy = Accuracy(num_classes=5)
-
-    y_true = torch.tensor([0, 3, 2, 3, 4])
-    y_pred = torch.tensor([0, 1, 2, 3, 4])
-
-    accuracy_score = accuracy(y_true, y_pred)
+    # Test the accuracy metric
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
+    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
+    accuracy = Accuracy(num_classes=6, macro_averaging=False)
+    accuracy(y_true, y_pred)
+    assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0"
+    accuracy.__reset__()
+    assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be NaN"
+    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
+    accuracy(y_true, y_pred)
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184"
+    accuracy.__reset__()
+    accuracy.macro_averaging = True
+    accuracy(y_true, y_pred)
+    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
+    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
+    accuracy(y_true_1, y_pred_1)
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184"
+    accuracy.macro_averaging = False
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184"
+    accuracy.__reset__()

-    assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
-        f"Accuracy Score: {accuracy_score.item()}"
-    )


 def test_entropypred():
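Assuming a standard pytest layout with the package importable from the repository root, the two adjusted tests can be run in isolation, for example from Python:

    import pytest

    # Run only the tests touched by this commit (hypothetical invocation).
    pytest.main(["tests/test_metrics.py", "-k", "test_metric_wrapper or test_accuracy", "-v"])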
