Skip to content

Commit d742fe6

Browse files
authored
Merge pull request #34 from SFI-Visual-Intelligence/johan/test
All seems to be working here 👍
2 parents 0b21d9d + d128e58 commit d742fe6

File tree

3 files changed

+64
-53
lines changed

3 files changed

+64
-53
lines changed

tests/test_metrics.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from utils.metrics import F1Score, Recall
1+
2+
from utils.metrics import F1Score, Precision, Recall
3+
24

35

46
def test_recall():
@@ -30,3 +32,55 @@ def test_f1score():
3032
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
3133
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
3234
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
35+
36+
37+
def test_precision_case1():
38+
import torch
39+
40+
for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]):
41+
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
42+
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
43+
P = Precision(3, use_mean=boolean)
44+
precision1 = P(true1, pred1)
45+
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
46+
f"Precision Score: {precision1.item()}"
47+
)
48+
49+
50+
def test_precision_case2():
51+
import torch
52+
53+
for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]):
54+
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
55+
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
56+
P = Precision(5, use_mean=boolean)
57+
precision2 = P(true2, pred2)
58+
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
59+
f"Precision Score: {precision2.item()}"
60+
)
61+
62+
63+
def test_precision_case3():
64+
import torch
65+
66+
for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]):
67+
true3 = torch.tensor([0, 0, 0, 1, 0])
68+
pred3 = torch.tensor([1, 0, 0, 1, 0])
69+
P = Precision(2, use_mean=boolean)
70+
precision3 = P(true3, pred3)
71+
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
72+
f"Precision Score: {precision3.item()}"
73+
)
74+
75+
76+
def test_for_zero_denominator():
77+
import torch
78+
79+
for boolean in [True, False]:
80+
true4 = torch.tensor([1, 1, 1, 1, 1])
81+
pred4 = torch.tensor([0, 0, 0, 0, 0])
82+
P = Precision(2, use_mean=boolean)
83+
precision4 = P(true4, pred4)
84+
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
85+
f"Precision Score: {precision4.item()}"
86+
)

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["EntropyPrediction", "Recall", "F1Score"]
1+
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]
22

33
from .EntropyPred import EntropyPrediction
44
from .F1 import F1Score
5+
from .precision import Precision
56
from .recall import Recall

utils/metrics/precision.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,23 @@
77

88

99
class Precision(nn.Module):
10-
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. This is for now controller with the USE_MEAN macro.
10+
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives.
1111
1212
Parameters
1313
----------
1414
num_classes : int
1515
Number of classes in the dataset.
16+
use_mean : bool
17+
Whether to calculate precision as a mean of precisions or as a brute function of true positives and false positives.
1618
"""
1719

18-
def __init__(self, num_classes):
20+
def __init__(self, num_classes: int, use_mean: bool = True):
1921
super().__init__()
2022

2123
self.num_classes = num_classes
24+
self.use_mean = use_mean
2225

23-
def forward(self, y_true, y_pred):
26+
def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
2427
"""Calculates the precision score given number of classes and the true and predicted labels.
2528
2629
Parameters
@@ -43,7 +46,7 @@ def forward(self, y_true, y_pred):
4346
1, y_pred.unsqueeze(1), 1
4447
)
4548

46-
if USE_MEAN:
49+
if self.use_mean:
4750
tp = torch.sum(true_oh * pred_oh, 0)
4851
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
4952

@@ -54,52 +57,5 @@ def forward(self, y_true, y_pred):
5457
return torch.nanmean(tp / (tp + fp))
5558

5659

57-
def test_precision_case1():
58-
true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10
59-
60-
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
61-
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
62-
P = Precision(3)
63-
precision1 = P(true1, pred1)
64-
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
65-
f"Precision Score: {precision1.item()}"
66-
)
67-
68-
69-
def test_precision_case2():
70-
true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15
71-
72-
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
73-
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
74-
P = Precision(5)
75-
precision2 = P(true2, pred2)
76-
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
77-
f"Precision Score: {precision2.item()}"
78-
)
79-
80-
81-
def test_precision_case3():
82-
true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5
83-
84-
true3 = torch.tensor([0, 0, 0, 1, 0])
85-
pred3 = torch.tensor([1, 0, 0, 1, 0])
86-
P = Precision(2)
87-
precision3 = P(true3, pred3)
88-
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
89-
f"Precision Score: {precision3.item()}"
90-
)
91-
92-
93-
def test_for_zero_denominator():
94-
true_precision = 0.0
95-
true4 = torch.tensor([1, 1, 1, 1, 1])
96-
pred4 = torch.tensor([0, 0, 0, 0, 0])
97-
P = Precision(2)
98-
precision4 = P(true4, pred4)
99-
assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), (
100-
f"Precision Score: {precision4.item()}"
101-
)
102-
103-
10460
if __name__ == "__main__":
10561
pass

0 commit comments

Comments
 (0)