@@ -78,56 +78,49 @@ def test_f1score():
     assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
 
 
-def test_precision_case1():
+def test_precision():
     import torch
-
-    for boolean, true_precision in zip([False, True], [25.0 / 36, 7.0 / 10]):
-        true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
-        pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
-        P = Precision(3, micro_averaging=boolean)
-        precision1 = P(true1, pred1)
-        assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
-            f"Precision Score: {precision1.item()}"
-        )
-
-
-def test_precision_case2():
-    import torch
-
-    for boolean, true_precision in zip([False, True], [8.0 / 15, 6.0 / 15]):
-        true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
-        pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
-        P = Precision(5, micro_averaging=boolean)
-        precision2 = P(true2, pred2)
-        assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
-            f"Precision Score: {precision2.item()}"
-        )
-
-
-def test_precision_case3():
-    import torch
-
-    for boolean, true_precision in zip([False, True], [3.0 / 4, 4.0 / 5]):
-        true3 = torch.tensor([0, 0, 0, 1, 0])
-        pred3 = torch.tensor([1, 0, 0, 1, 0])
-        P = Precision(2, micro_averaging=boolean)
-        precision3 = P(true3, pred3)
-        assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
-            f"Precision Score: {precision3.item()}"
-        )
-
-
-def test_for_zero_denominator():
-    import torch
-
-    for boolean in [False, True]:
-        true4 = torch.tensor([1, 1, 1, 1, 1])
-        pred4 = torch.tensor([0, 0, 0, 0, 0])
-        P = Precision(2, micro_averaging=boolean)
-        precision4 = P(true4, pred4)
-        assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
-            f"Precision Score: {precision4.item()}"
-        )
+    import numpy as np
+    from sklearn.metrics import precision_score
+    from random import randint
+
+    C = randint(2, 10)  # number of classes
+    N = randint(2, 10 * C)  # batch size
+    y_true = torch.randint(0, C, (N,))
+    logits = torch.randn(N, C)
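+    # the sizes and data are unseeded, so each run exercises a different case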
+
+    # create metric objects
+    precision_micro = Precision(num_classes=C)
+    precision_macro = Precision(num_classes=C, macro_averaging=True)
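+    # micro averaging pools TP and FP counts across all classes before dividing,
+    # while macro averaging takes the unweighted mean of the per-class precisions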
+
+    # find scores
+    micro_precision_score = precision_micro(y_true, logits)
+    macro_precision_score = precision_macro(y_true, logits)
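+    # for single-label multiclass data, micro-averaged precision equals accuracy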
+
+    # check that the outputs are tensors
+    assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
+    assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected."
+
+    # check for non-negativity
+    assert micro_precision_score.item() >= 0, "Expected non-negative value"
+    assert macro_precision_score.item() >= 0, "Expected non-negative value"
+
+    # find predictions (the argmax keyword is keepdim, not keepdims)
+    y_pred = logits.argmax(dim=-1, keepdim=True)
+
+    # check dimensions; `shape == a or b` is always truthy, so test membership instead
+    assert y_true.shape in (torch.Size([N, 1]), torch.Size([N]))
+    assert logits.shape == torch.Size([N, C])
+    assert y_pred.shape in (torch.Size([N, 1]), torch.Size([N]))
+
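+    # scikit-learn converts the CPU torch tensors to numpy arrays implicitly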
+    # find reference values with scikit-learn
+    scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
+    scikit_micro_precision = precision_score(y_true, y_pred, average="micro")
+
+    # check for similarity
+    assert np.isclose(scikit_micro_precision, micro_precision_score.item(), atol=1e-5), "Score does not match scikit-learn's score"
+    assert np.isclose(scikit_macro_precision, macro_precision_score.item(), atol=1e-5), "Score does not match scikit-learn's score"
 
 
 def test_accuracy():