
Commit 562800d

added precision test to test_metrics.py
1 parent b1a3627 commit 562800d


tests/test_metrics.py

Lines changed: 53 additions & 1 deletion
@@ -1,4 +1,4 @@
-from utils.metrics import Recall, F1Score
+from utils.metrics import F1Score, Precision, Recall
 
 
 def test_recall():
@@ -30,3 +30,55 @@ def test_f1score():
     assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
     assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
     assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
+
+
+def test_precision_case1():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]):
+        true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
+        pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
+        P = Precision(3, use_mean=boolean)
+        precision1 = P(true1, pred1)
+        assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision1.item()}"
+        )
+
+
+def test_precision_case2():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]):
+        true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
+        pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
+        P = Precision(5, use_mean=boolean)
+        precision2 = P(true2, pred2)
+        assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision2.item()}"
+        )
+
+
+def test_precision_case3():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]):
+        true3 = torch.tensor([0, 0, 0, 1, 0])
+        pred3 = torch.tensor([1, 0, 0, 1, 0])
+        P = Precision(2, use_mean=boolean)
+        precision3 = P(true3, pred3)
+        assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision3.item()}"
+        )
+
+
+def test_for_zero_denominator():
+    import torch
+
+    for boolean in [True, False]:
+        true4 = torch.tensor([1, 1, 1, 1, 1])
+        pred4 = torch.tensor([0, 0, 0, 0, 0])
+        P = Precision(2, use_mean=boolean)
+        precision4 = P(true4, pred4)
+        assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
+            f"Precision Score: {precision4.item()}"
+        )
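Note: the commit imports Precision from utils.metrics, but the metric implementation itself is not part of this diff. The expected values are consistent with use_mean=True meaning macro-averaged precision and use_mean=False meaning micro-averaged precision: in test_precision_case1 the per-class precisions are 3/4, 2/3 and 2/3, so the macro average is 25/36, while 7 of the 10 predictions are correct, giving a micro precision of 7/10. Below is a minimal sketch of a Precision class matching the call signature the tests rely on (Precision(num_classes, use_mean=...) applied to true and predicted label tensors); the internals and names are assumptions for illustration, not the actual utils.metrics code.

# Minimal sketch (assumption): a Precision metric consistent with the tests above.
# The real utils.metrics.Precision is not shown in this commit.
import torch


class Precision:
    def __init__(self, num_classes: int, use_mean: bool = True):
        self.num_classes = num_classes
        self.use_mean = use_mean  # True: macro-average, False: micro-average

    def __call__(self, true: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        # Per-class true-positive and predicted-positive counts.
        tp = torch.zeros(self.num_classes)
        pp = torch.zeros(self.num_classes)
        for c in range(self.num_classes):
            pred_c = pred == c
            tp[c] = (pred_c & (true == c)).sum()
            pp[c] = pred_c.sum()
        if self.use_mean:
            # Macro: average per-class precision, treating 0/0 as 0.
            per_class = torch.where(pp > 0, tp / pp.clamp(min=1), torch.zeros_like(tp))
            return per_class.mean()
        # Micro: pool counts over all classes.
        total_pred = pp.sum()
        return tp.sum() / total_pred if total_pred > 0 else torch.tensor(0.0)

Under this reading, test_for_zero_denominator exercises the 0/0 case: class 1 is never predicted, so its per-class precision has a zero denominator and is treated as 0 rather than NaN, and since no prediction is correct both averaging modes return 0.0.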
