Skip to content

Commit 01284a3

Browse files
Fix rbf kernel test parameters
1 parent fcda819 commit 01284a3

File tree

1 file changed

+58
-15
lines changed

1 file changed

+58
-15
lines changed

frouros/tests/unit/utils/test_kernels.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,74 @@
88

99
# TODO: Create fixtures for the matrices and the expected kernel values
1010

11+
1112
@pytest.mark.parametrize(
1213
"X, Y, sigma, expected_kernel_value",
1314
[
1415
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 0.5, np.array([[1.0]])),
1516
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 1.0, np.array([[1.0]])),
1617
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 2.0, np.array([[1.0]])),
17-
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 0.5, np.array([[3.53262857e-24]])),
18-
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 1.0, np.array([[1.37095909e-06]])),
18+
(
19+
np.array([[1, 2, 3]]),
20+
np.array([[4, 5, 6]]),
21+
0.5,
22+
np.array([[3.53262857e-24]]),
23+
),
24+
(
25+
np.array([[1, 2, 3]]),
26+
np.array([[4, 5, 6]]),
27+
1.0,
28+
np.array([[1.37095909e-06]]),
29+
),
1930
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 2.0, np.array([[0.03421812]])),
20-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 0.5,
21-
np.array([[1.00000000e+00, 3.53262857e-24], [3.53262857e-24, 1.00000000e+00]])),
22-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 1.0,
23-
np.array([[1.00000000e+00, 1.37095909e-06], [1.37095909e-06, 1.00000000e+00]])),
24-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 2.0,
25-
np.array([[1.00000000e+00, 0.03421812], [0.03421812, 1.00000000e+00]])),
26-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5,5.5, 6.5]]), 0.5,
27-
np.array([[2.23130160e-01, 1.20048180e-32], [5.17555501e-17, 2.23130160e-01]])),
28-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]), 1.0,
29-
np.array([[6.87289279e-01, 1.04674018e-08], [8.48182352e-05, 6.87289279e-01]])),
30-
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]), 2.0,
31-
np.array([[0.91051036, 0.01011486], [0.09596709, 0.91051036]])),
31+
(
32+
np.array([[1, 2, 3], [4, 5, 6]]),
33+
np.array([[1, 2, 3], [4, 5, 6]]),
34+
0.5,
35+
np.array(
36+
[[1.00000000e00, 3.53262857e-24], [3.53262857e-24, 1.00000000e00]]
37+
),
38+
),
39+
(
40+
np.array([[1, 2, 3], [4, 5, 6]]),
41+
np.array([[1, 2, 3], [4, 5, 6]]),
42+
1.0,
43+
np.array(
44+
[[1.00000000e00, 1.37095909e-06], [1.37095909e-06, 1.00000000e00]]
45+
),
46+
),
47+
(
48+
np.array([[1, 2, 3], [4, 5, 6]]),
49+
np.array([[1, 2, 3], [4, 5, 6]]),
50+
2.0,
51+
np.array([[1.00000000e00, 0.03421812], [0.03421812, 1.00000000e00]]),
52+
),
53+
(
54+
np.array([[1, 2, 3], [4, 5, 6]]),
55+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
56+
0.5,
57+
np.array(
58+
[[2.23130160e-01, 1.20048180e-32], [5.17555501e-17, 2.23130160e-01]]
59+
),
60+
),
61+
(
62+
np.array([[1, 2, 3], [4, 5, 6]]),
63+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
64+
1.0,
65+
np.array(
66+
[[6.87289279e-01, 1.04674018e-08], [8.48182352e-05, 6.87289279e-01]]
67+
),
68+
),
69+
(
70+
np.array([[1, 2, 3], [4, 5, 6]]),
71+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
72+
2.0,
73+
np.array([[0.91051036, 0.01011486], [0.09596709, 0.91051036]]),
74+
),
3275
],
3376
)
3477
def test_rbf_kernel(
35-
X: np.ndarray,
78+
X: np.ndarray, # noqa: N803
3679
Y: np.ndarray,
3780
sigma: float,
3881
expected_kernel_value: np.ndarray,

0 commit comments

Comments
 (0)