Skip to content

Commit 31579ce

Browse files
Add test_kernels.py file
1 parent ca251bc commit 31579ce

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Test kernels module."""
2+
3+
import numpy as np # type: ignore
4+
import pytest # type: ignore
5+
6+
from frouros.utils.kernels import rbf_kernel
7+
8+
9+
# TODO: Create fixtures for the matrices and the expected kernel values
10+
11+
@pytest.mark.parametrize(
12+
"X, Y, sigma, expected_kernel_value",
13+
[
14+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 0.5, np.array([[1.0]])),
15+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 1.0, np.array([[1.0]])),
16+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 2.0, np.array([[1.0]])),
17+
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 0.5, np.array([[3.53262857e-24]])),
18+
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 1.0, np.array([[1.37095909e-06]])),
19+
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 2.0, np.array([[0.03421812]])),
20+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 0.5,
21+
np.array([[1.00000000e+00, 3.53262857e-24], [3.53262857e-24, 1.00000000e+00]])),
22+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 1.0,
23+
np.array([[1.00000000e+00, 1.37095909e-06], [1.37095909e-06, 1.00000000e+00]])),
24+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 2.0,
25+
np.array([[1.00000000e+00, 0.03421812], [0.03421812, 1.00000000e+00]])),
26+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5,5.5, 6.5]]), 0.5,
27+
np.array([[2.23130160e-01, 1.20048180e-32], [5.17555501e-17, 2.23130160e-01]])),
28+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]), 1.0,
29+
np.array([[6.87289279e-01, 1.04674018e-08], [8.48182352e-05, 6.87289279e-01]])),
30+
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]), 2.0,
31+
np.array([[0.91051036, 0.01011486], [0.09596709, 0.91051036]])),
32+
],
33+
)
34+
def test_rbf_kernel(
35+
X: np.ndarray,
36+
Y: np.ndarray,
37+
sigma: float,
38+
expected_kernel_value: np.ndarray,
39+
) -> None:
40+
"""Test rbf kernel.
41+
42+
:param X: X values
43+
:type X: numpy.ndarray
44+
:param Y: Y values
45+
:type Y: numpy.ndarray
46+
:param sigma: sigma value
47+
:type sigma: float
48+
:param expected_kernel_value: expected kernel value
49+
:type expected_kernel_value: numpy.ndarray
50+
"""
51+
assert np.all(
52+
np.isclose(
53+
rbf_kernel(
54+
X=X,
55+
Y=Y,
56+
sigma=sigma,
57+
),
58+
expected_kernel_value,
59+
),
60+
)

0 commit comments

Comments
 (0)