Skip to content

Commit a4df0f2

Browse files
committed
updated precision test to fit new micro_averaging argument in dataloader
1 parent e10cf73 commit a4df0f2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_precision_case1():
3838
for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]):
3939
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
4040
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
41-
P = Precision(3, use_mean=boolean)
41+
P = Precision(3, micro_averaging=boolean)
4242
precision1 = P(true1, pred1)
4343
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
4444
f"Precision Score: {precision1.item()}"
@@ -51,7 +51,7 @@ def test_precision_case2():
5151
for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]):
5252
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
5353
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
54-
P = Precision(5, use_mean=boolean)
54+
P = Precision(5, micro_averaging=boolean)
5555
precision2 = P(true2, pred2)
5656
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
5757
f"Precision Score: {precision2.item()}"
@@ -64,7 +64,7 @@ def test_precision_case3():
6464
for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]):
6565
true3 = torch.tensor([0, 0, 0, 1, 0])
6666
pred3 = torch.tensor([1, 0, 0, 1, 0])
67-
P = Precision(2, use_mean=boolean)
67+
P = Precision(2, micro_averaging=boolean)
6868
precision3 = P(true3, pred3)
6969
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
7070
f"Precision Score: {precision3.item()}"
@@ -77,7 +77,7 @@ def test_for_zero_denominator():
7777
for boolean in [True, False]:
7878
true4 = torch.tensor([1, 1, 1, 1, 1])
7979
pred4 = torch.tensor([0, 0, 0, 0, 0])
80-
P = Precision(2, use_mean=boolean)
80+
P = Precision(2, micro_averaging=boolean)
8181
precision4 = P(true4, pred4)
8282
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
8383
f"Precision Score: {precision4.item()}"

0 commit comments

Comments
 (0)