Commit bab6aee (parent 08aa876)

Add test for metricwrapper and all metrics

Note that the precision metric needs a renaming pass: the wrapper is called with the argument macro_averaging = True/False, but Precision itself exposes the argument micro_averaging.
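The note above records the argument-name mismatch but not a fix. A minimal sketch of how the wrapper could bridge the two flags, assuming Precision's micro_averaging is the logical inverse of the wrapper's macro_averaging; the helper make_precision is hypothetical and not part of this commit:

    from utils.metrics import Precision

    def make_precision(num_classes: int, macro_averaging: bool) -> Precision:
        # Hypothetical bridge: pass the inverted flag under Precision's own
        # argument name. Assumes the two flags are exact opposites.
        return Precision(num_classes=num_classes, micro_averaging=not macro_averaging)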

File tree: 1 file changed (+43 −0)
tests/test_metrics.py

@@ -1,6 +1,49 @@
+from random import randint
+
+import pytest
+
+from utils.load_metric import MetricWrapper
 from utils.metrics import Accuracy, F1Score, Precision, Recall
 
 
+@pytest.mark.parametrize(
+    "metric, num_classes, macro_averaging",
+    [
+        ("f1", randint(2, 10), False),
+        ("f1", randint(2, 10), True),
+        ("recall", randint(2, 10), False),
+        ("recall", randint(2, 10), True),
+        ("accuracy", randint(2, 10), False),
+        ("accuracy", randint(2, 10), True),
+        ("precision", randint(2, 10), False),
+        ("precision", randint(2, 10), True),
+        # TODO: Add test for EntropyPrediction
+    ],
+)
+def test_metric_wrapper(metric, num_classes, macro_averaging):
+    import numpy as np
+    import torch
+
+    y_true = torch.arange(num_classes, dtype=torch.int64)
+    logits = torch.rand(num_classes, num_classes)
+
+    metrics = MetricWrapper(
+        metric,
+        num_classes=num_classes,
+        macro_averaging=macro_averaging,
+    )
+
+    metrics(y_true, logits)
+    score = metrics.accumulate()
+    metrics.reset()
+    empty_score = metrics.accumulate()
+
+    assert isinstance(score, dict), "Expected a dictionary output."
+    assert metric in score, f"Expected {metric} metric in the output."
+    assert score[metric] >= 0, "Expected a non-negative value."
+    assert np.isnan(empty_score[metric]), "Expected an empty list."
+
+
 def test_recall():
     import torch
 
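The TODO in the parameter list leaves EntropyPrediction untested. A minimal sketch of the missing parametrize entries, assuming the wrapper registers the metric under the key "entropy" and accepts the same macro_averaging flag (both are assumptions; this commit confirms neither):

    # Hypothetical entries for the TODO above; the "entropy" key is assumed.
    ("entropy", randint(2, 10), False),
    ("entropy", randint(2, 10), True),

The new test can be run in isolation with: pytest tests/test_metrics.py -k test_metric_wrapper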