@@ -79,48 +79,52 @@ def test_f1score():
7979
8080
8181def test_precision ():
82- import torch
82+ from random import randint
83+
8384 import numpy as np
85+ import torch
8486 from sklearn .metrics import precision_score
85- from random import randint
86-
87- C = randint (2 , 10 ) # number of classes
88- N = randint (2 ,10 * C ) # batchsize
89- y_true = torch .randint (0 ,C , (N ,))
87+
88+ C = randint (2 , 10 ) # number of classes
89+ N = randint (2 , 10 * C ) # batchsize
90+ y_true = torch .randint (0 , C , (N ,))
9091 logits = torch .randn (N , C )
91-
92+
9293 # create metric objects
9394 precision_micro = Precision (num_classes = C )
9495 precision_macro = Precision (num_classes = C , macro_averaging = True )
95-
96+
9697 # find scores
9798 micro_precision_score = precision_micro (y_true , logits )
9899 macro_precision_score = precision_macro (y_true , logits )
99-
100+
100101 # check output to be tensor
101102 assert isinstance (micro_precision_score , torch .Tensor ), "Tensor output is expected."
102103 assert isinstance (macro_precision_score , torch .Tensor ), "Tensor output is expected."
103-
104+
104105 # check for non-negativity
105106 assert micro_precision_score .item () >= 0 , "Expected non-negative value"
106107 assert macro_precision_score .item () >= 0 , "Expected non-negative value"
107-
108+
108109 # find predictions
109110 y_pred = logits .argmax (dim = - 1 , keepdims = True )
110-
111+
111112 # check dimension
112- assert y_true .shape == torch .Size ([N ,1 ]) or torch .Size ([N ])
113- assert logits .shape == torch .Size ([N ,C ])
114- assert y_pred .shape == torch .Size ([N ,1 ]) or torch .Size ([N ])
113+ assert y_true .shape == torch .Size ([N , 1 ]) or torch .Size ([N ])
114+ assert logits .shape == torch .Size ([N , C ])
115+ assert y_pred .shape == torch .Size ([N , 1 ]) or torch .Size ([N ])
115116
116-
117117 # find true values with scikit learn
118118 scikit_macro_precision = precision_score (y_true , y_pred , average = "macro" )
119119 scikit_micro_precision = precision_score (y_true , y_pred , average = "micro" )
120-
120+
121121 # check for similarity
122- assert np .isclose (scikit_micro_precision , micro_precision_score , atol = 1e-5 ), "Score does not match scikit's score"
123- assert np .isclose (scikit_macro_precision , macro_precision_score , atol = 1e-5 ), "Score does not match scikit's score"
122+ assert np .isclose (scikit_micro_precision , micro_precision_score , atol = 1e-5 ), (
123+ "Score does not match scikit's score"
124+ )
125+ assert np .isclose (scikit_macro_precision , macro_precision_score , atol = 1e-5 ), (
126+ "Score does not match scikit's score"
127+ )
124128
125129
126130def test_accuracy ():
0 commit comments