Skip to content

Commit 1574c7e

Browse files
TST: Add tests for metric utilities and scorers
1 parent 37c33e2 commit 1574c7e

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""Tests for the metrics module utilities."""
2+
3+
import numpy.testing as npt
4+
import pytest
5+
6+
from orca_python.metrics import (
7+
accuracy_off1,
8+
amae,
9+
ccr,
10+
gm,
11+
gmsec,
12+
mae,
13+
mmae,
14+
ms,
15+
mze,
16+
rps,
17+
spearman,
18+
tkendall,
19+
wkappa,
20+
)
21+
from orca_python.metrics.utils import (
22+
_METRICS,
23+
compute_metric,
24+
get_metric_names,
25+
greater_is_better,
26+
load_metric_as_scorer,
27+
)
28+
29+
30+
def test_get_metric_names():
    """Check the list returned by get_metric_names against _METRICS.

    The result must be a plain ``list`` whose first three entries are
    "accuracy_off1", "amae" and "ccr", must contain "rps", and must hold
    the same names as the keys of ``_METRICS`` once both are sorted.
    """
    names = get_metric_names()
    registry_names = list(_METRICS.keys())

    # Exact type check: subclasses of list are deliberately not accepted.
    assert type(names) is list

    leading = names[:3]
    assert leading == ["accuracy_off1", "amae", "ccr"]
    assert "rps" in names

    # Order-insensitive comparison with the registry keys.
    npt.assert_array_equal(sorted(names), sorted(registry_names))
39+
40+
41+
@pytest.mark.parametrize(
    "metric_name, gib",
    [
        ("accuracy_off1", True),
        ("amae", False),
        ("ccr", True),
        ("gm", True),
        ("gmsec", True),
        ("mae", False),
        ("mmae", False),
        ("ms", True),
        ("mze", False),
        ("rps", False),
        ("spearman", True),
        ("tkendall", True),
        ("wkappa", True),
    ],
)
def test_greater_is_better(metric_name, gib):
    """Check the flag reported by greater_is_better for every known metric.

    Each case pairs a metric name with the boolean that
    ``greater_is_better`` is expected to return for it.
    """
    reported = greater_is_better(metric_name)

    assert reported == gib
62+
63+
64+
def test_greater_is_better_invalid_name():
    """Test that greater_is_better raises KeyError for an unknown metric name.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    and applies ``re.search`` to the string form of the raised exception.
    The trailing period is therefore escaped so that it only matches a
    literal "." instead of any character.
    """
    error_msg = r"Unrecognized metric name: 'roc_auc'\."

    with pytest.raises(KeyError, match=error_msg):
        greater_is_better("roc_auc")
70+
71+
72+
@pytest.mark.parametrize(
    "metric_name, metric",
    [
        ("rps", rps),
        ("ccr", ccr),
        ("accuracy_off1", accuracy_off1),
        ("gm", gm),
        ("gmsec", gmsec),
        ("mae", mae),
        ("mmae", mmae),
        ("amae", amae),
        ("ms", ms),
        ("mze", mze),
        ("tkendall", tkendall),
        ("wkappa", wkappa),
        ("spearman", spearman),
    ],
)
def test_load_metric_as_scorer(metric_name, metric):
    """Check that the scorer built for a metric name wraps the right function.

    The scorer's ``_score_func`` must equal the metric function itself, and
    its ``_sign`` must be 1 when ``greater_is_better`` reports True for the
    metric and -1 otherwise.
    """
    scorer = load_metric_as_scorer(metric_name)

    assert scorer._score_func == metric

    actual_sign = scorer._sign
    if greater_is_better(metric_name):
        expected_sign = 1
    else:
        expected_sign = -1
    assert actual_sign == expected_sign
96+
97+
98+
@pytest.mark.parametrize(
    "metric_name, metric",
    [
        ("ccr", ccr),
        ("accuracy_off1", accuracy_off1),
        ("gm", gm),
        ("gmsec", gmsec),
        ("mae", mae),
        ("mmae", mmae),
        ("amae", amae),
        ("ms", ms),
        ("mze", mze),
        ("tkendall", tkendall),
        ("wkappa", wkappa),
        ("spearman", spearman),
    ],
)
def test_correct_metric_output(metric_name, metric):
    """Check that a loaded scorer's function agrees with the metric itself.

    The metric is evaluated directly and through the ``_score_func`` of the
    scorer loaded by name, on the same fixed label vectors; both values must
    agree to six decimal places.
    """
    targets = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
    predictions = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]

    scorer = load_metric_as_scorer(metric_name)
    expected = metric(targets, predictions)
    actual = scorer._score_func(targets, predictions)

    npt.assert_almost_equal(actual, expected, decimal=6)
125+
126+
127+
def test_load_metric_invalid_name():
    """Test that load_metric_as_scorer rejects invalid metric names.

    Two failure modes are checked: a non-string name must raise TypeError,
    and an unknown string name must raise KeyError.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    and applies ``re.search`` to the string form of the raised exception, so
    the trailing periods are escaped to match a literal "." only.
    """
    error_msg = r"metric_name must be a string\."
    with pytest.raises(TypeError, match=error_msg):
        load_metric_as_scorer(123)

    error_msg = r"Unrecognized metric name: 'roc_auc'\."
    with pytest.raises(KeyError, match=error_msg):
        load_metric_as_scorer("roc_auc")
136+
137+
138+
@pytest.mark.parametrize(
    "metric_name",
    [
        "ccr",
        "accuracy_off1",
        "gm",
        "gmsec",
        "mae",
        "mmae",
        "amae",
        "ms",
        "mze",
        "tkendall",
        "wkappa",
        "spearman",
    ],
)
def test_compute_metric(metric_name) -> None:
    """Check compute_metric against the function wrapped by the loaded scorer.

    The value returned by ``compute_metric`` for a metric name must agree, to
    six decimal places, with calling the ``_score_func`` of the scorer loaded
    for that same name on the same fixed label vectors.
    """
    targets = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
    predictions = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]

    computed = compute_metric(metric_name, targets, predictions)

    scorer = load_metric_as_scorer(metric_name)
    reference = scorer._score_func(targets, predictions)

    npt.assert_almost_equal(computed, reference, decimal=6)
164+
165+
166+
def test_compute_metric_invalid_name():
    """Test that compute_metric raises KeyError for an unknown metric name.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    and applies ``re.search`` to the string form of the raised exception.
    The trailing period is therefore escaped so that it only matches a
    literal "." instead of any character.
    """
    error_msg = r"Unrecognized metric name: 'roc_auc'\."

    with pytest.raises(KeyError, match=error_msg):
        compute_metric("roc_auc", [1, 2, 3], [1, 2, 3])

0 commit comments

Comments
 (0)