Skip to content

Commit ed0eaf2

Browse files
authored
Merge pull request #39 from SFI-Visual-Intelligence/Jan/accuracy
Added accuracy and tests for it and Jan model, no clashes merging myself
2 parents d742fe6 + 46798d2 commit ed0eaf2

File tree

5 files changed

+70
-8
lines changed

5 files changed

+70
-8
lines changed

tests/test_metrics.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
2-
from utils.metrics import F1Score, Precision, Recall
3-
1+
from utils.metrics import Accuracy, F1Score, Precision, Recall
42

53

64
def test_recall():
@@ -84,3 +82,18 @@ def test_for_zero_denominator():
8482
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
8583
f"Precision Score: {precision4.item()}"
8684
)
85+
86+
87+
def test_accuracy():
88+
import torch
89+
90+
accuracy = Accuracy()
91+
92+
y_true = torch.tensor([0, 3, 2, 3, 4])
93+
y_pred = torch.tensor([0, 1, 2, 3, 4])
94+
95+
accuracy_score = accuracy(y_true, y_pred)
96+
97+
assert (torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5), (
98+
f"Accuracy Score: {accuracy_score.item()}"
99+
)

tests/test_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel
4+
from utils.models import ChristianModel, JanModel
55

66

77
@pytest.mark.parametrize(
@@ -20,3 +20,18 @@ def test_christian_model(image_shape, num_classes):
2020
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
2121
f"Softmax output should sum to 1, but got: {y.sum()}"
2222
)
23+
24+
25+
@pytest.mark.parametrize(
26+
"image_shape, num_classes",
27+
[((1, 28, 28), 4), ((3, 16, 16), 10)],
28+
)
29+
def test_jan_model(image_shape, num_classes):
30+
n, c, h, w = 5, *image_shape
31+
32+
model = JanModel(image_shape, num_classes)
33+
34+
x = torch.randn(n, c, h, w)
35+
y = model(x)
36+
37+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"

utils/load_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch.nn as nn
55

6-
from .metrics import EntropyPrediction, F1Score, precision
6+
from .metrics import Accuracy, EntropyPrediction, F1Score, Precision
77

88

99
class MetricWrapper(nn.Module):
@@ -39,9 +39,9 @@ def _get_metric(self, key):
3939
case "recall":
4040
raise NotImplementedError("Recall score not implemented yet")
4141
case "precision":
42-
return precision()
42+
return Precision()
4343
case "accuracy":
44-
raise NotImplementedError("Accuracy score not implemented yet")
44+
return Accuracy()
4545
case _:
4646
raise ValueError(f"Metric {key} not supported")
4747

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]
1+
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]
22

3+
from .accuracy import Accuracy
34
from .EntropyPred import EntropyPrediction
45
from .F1 import F1Score
56
from .precision import Precision

utils/metrics/accuracy.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class Accuracy(nn.Module):
6+
def __init__(self):
7+
super().__init__()
8+
9+
def forward(self, y_true, y_pred):
10+
"""
11+
Compute the accuracy of the model.
12+
13+
Parameters
14+
----------
15+
y_true : torch.Tensor
16+
True labels.
17+
y_pred : torch.Tensor
18+
Predicted labels.
19+
20+
Returns
21+
-------
22+
float
23+
Accuracy score.
24+
"""
25+
return (y_true == y_pred).float().mean().item()
26+
27+
28+
if __name__ == "__main__":
29+
y_true = torch.tensor([0, 3, 2, 3, 4])
30+
y_pred = torch.tensor([0, 1, 2, 3, 4])
31+
32+
accuracy = Accuracy()
33+
print(accuracy(y_true, y_pred))

0 commit comments

Comments
 (0)