Skip to content

Commit e7ba8a8

Browse files
committed
ruffedisorted
1 parent 7c7a80d commit e7ba8a8

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

tests/test_metrics.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,48 +79,52 @@ def test_f1score():
7979

8080

8181
def test_precision():
82-
import torch
82+
from random import randint
83+
8384
import numpy as np
85+
import torch
8486
from sklearn.metrics import precision_score
85-
from random import randint
86-
87-
C = randint(2, 10) # number of classes
88-
N = randint(2,10*C) # batchsize
89-
y_true = torch.randint(0,C, (N,))
87+
88+
C = randint(2, 10) # number of classes
89+
N = randint(2, 10 * C) # batchsize
90+
y_true = torch.randint(0, C, (N,))
9091
logits = torch.randn(N, C)
91-
92+
9293
# create metric objects
9394
precision_micro = Precision(num_classes=C)
9495
precision_macro = Precision(num_classes=C, macro_averaging=True)
95-
96+
9697
# find scores
9798
micro_precision_score = precision_micro(y_true, logits)
9899
macro_precision_score = precision_macro(y_true, logits)
99-
100+
100101
# check output to be tensor
101102
assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
102103
assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected."
103-
104+
104105
# check for non-negativity
105106
assert micro_precision_score.item() >= 0, "Expected non-negative value"
106107
assert macro_precision_score.item() >= 0, "Expected non-negative value"
107-
108+
108109
# find predictions
109110
y_pred = logits.argmax(dim=-1, keepdims=True)
110-
111+
111112
# check dimension
112-
assert y_true.shape == torch.Size([N,1]) or torch.Size([N])
113-
assert logits.shape == torch.Size([N,C])
114-
assert y_pred.shape == torch.Size([N,1]) or torch.Size([N])
113+
assert y_true.shape == torch.Size([N, 1]) or torch.Size([N])
114+
assert logits.shape == torch.Size([N, C])
115+
assert y_pred.shape == torch.Size([N, 1]) or torch.Size([N])
115116

116-
117117
# find true values with scikit learn
118118
scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
119119
scikit_micro_precision = precision_score(y_true, y_pred, average="micro")
120-
120+
121121
# check for similarity
122-
assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), "Score does not match scikit's score"
123-
assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), "Score does not match scikit's score"
122+
assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), (
123+
"Score does not match scikit's score"
124+
)
125+
assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), (
126+
"Score does not match scikit's score"
127+
)
124128

125129

126130
def test_accuracy():

0 commit comments

Comments
 (0)