Skip to content

Commit b9b7158

Browse files
committed
Updated precision metric and test function, need to discuss shape of y_true. is it ([N,]) or ([N,1])?
1 parent ba90d89 commit b9b7158

File tree

2 files changed

+45
-50
lines changed

2 files changed

+45
-50
lines changed

tests/test_metrics.py

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -78,56 +78,49 @@ def test_f1score():
7878
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
7979

8080

81-
def test_precision_case1():
81+
def test_precision():
8282
import torch
83-
84-
for boolean, true_precision in zip([False, True], [25.0 / 36, 7.0 / 10]):
85-
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
86-
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
87-
P = Precision(3, micro_averaging=boolean)
88-
precision1 = P(true1, pred1)
89-
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
90-
f"Precision Score: {precision1.item()}"
91-
)
92-
93-
94-
def test_precision_case2():
95-
import torch
96-
97-
for boolean, true_precision in zip([False, True], [8.0 / 15, 6.0 / 15]):
98-
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
99-
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
100-
P = Precision(5, micro_averaging=boolean)
101-
precision2 = P(true2, pred2)
102-
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
103-
f"Precision Score: {precision2.item()}"
104-
)
105-
106-
107-
def test_precision_case3():
108-
import torch
109-
110-
for boolean, true_precision in zip([False, True], [3.0 / 4, 4.0 / 5]):
111-
true3 = torch.tensor([0, 0, 0, 1, 0])
112-
pred3 = torch.tensor([1, 0, 0, 1, 0])
113-
P = Precision(2, micro_averaging=boolean)
114-
precision3 = P(true3, pred3)
115-
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
116-
f"Precision Score: {precision3.item()}"
117-
)
118-
119-
120-
def test_for_zero_denominator():
121-
import torch
122-
123-
for boolean in [False, True]:
124-
true4 = torch.tensor([1, 1, 1, 1, 1])
125-
pred4 = torch.tensor([0, 0, 0, 0, 0])
126-
P = Precision(2, micro_averaging=boolean)
127-
precision4 = P(true4, pred4)
128-
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
129-
f"Precision Score: {precision4.item()}"
130-
)
83+
import numpy as np
84+
from sklearn.metrics import precision_score
85+
from random import randint
86+
87+
C = randint(2, 10) # number of classes
88+
N = randint(2,10*C) # batchsize
89+
y_true = torch.randint(0,C, (N,))
90+
logits = torch.randn(N, C)
91+
92+
# create metric objects
93+
precision_micro = Precision(num_classes=C)
94+
precision_macro = Precision(num_classes=C, macro_averaging=True)
95+
96+
# find scores
97+
micro_precision_score = precision_micro(y_true, logits)
98+
macro_precision_score = precision_macro(y_true, logits)
99+
100+
# check output to be tensor
101+
assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
102+
assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected."
103+
104+
# check for non-negativity
105+
assert micro_precision_score.item() >= 0, "Expected non-negative value"
106+
assert macro_precision_score.item() >= 0, "Expected non-negative value"
107+
108+
# find predictions
109+
y_pred = logits.argmax(dim=-1, keepdims=True)
110+
111+
# check dimension
112+
assert y_true.shape == torch.Size([N,1]) or torch.Size([N])
113+
assert logits.shape == torch.Size([N,C])
114+
assert y_pred.shape == torch.Size([N,1]) or torch.Size([N])
115+
116+
117+
# find true values with scikit learn
118+
scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
119+
scikit_micro_precision = precision_score(y_true, y_pred, average="micro")
120+
121+
# check for similarity
122+
assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), "Score does not match scikit's score"
123+
assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), "Score does not match scikit's score"
131124

132125

133126
def test_accuracy():

utils/metrics/precision.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
1919
self.num_classes = num_classes
2020
self.macro_averaging = macro_averaging
2121

22-
def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
22+
def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
2323
"""Compute precision of model
2424
2525
Parameters
@@ -34,6 +34,7 @@ def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
3434
torch.tensor
3535
Precision score
3636
"""
37+
y_pred = logits.argmax(dim=-1)
3738
return (
3839
self._macro_avg_precision(y_true, y_pred)
3940
if self.macro_averaging
@@ -57,6 +58,7 @@ def _micro_avg_precision(
5758
torch.tensor
5859
Micro-averaged precision
5960
"""
61+
print(y_true.shape)
6062
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
6163
1, y_true.unsqueeze(1), 1
6264
)

0 commit comments

Comments
 (0)