Skip to content

Commit ea25e65

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #247 from IFCA/fix-rbf-kernel
Fix rbf kernel
2 parents 05fe195 + 01284a3 commit ea25e65

File tree

4 files changed

+153
-40
lines changed

4 files changed

+153
-40
lines changed

docs/source/examples/data_drift/MMD_simple.ipynb

Lines changed: 28 additions & 22 deletions
Large diffs are not rendered by default.

frouros/detectors/data_drift/batch/distance_based/mmd.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,14 @@
66

77
import numpy as np # type: ignore
88
import tqdm # type: ignore
9-
from scipy.spatial.distance import cdist # type: ignore
109

1110
from frouros.callbacks.batch.base import BaseCallbackBatch
1211
from frouros.detectors.data_drift.base import MultivariateData
1312
from frouros.detectors.data_drift.batch.distance_based.base import (
1413
BaseDistanceBased,
1514
DistanceResult,
1615
)
17-
18-
19-
def rbf_kernel(
20-
X: np.ndarray, Y: np.ndarray, std: float = 1.0 # noqa: N803
21-
) -> np.ndarray:
22-
"""Radial basis function kernel between X and Y matrices.
23-
24-
:param X: X matrix
25-
:type X: numpy.ndarray
26-
:param Y: Y matrix
27-
:type Y: numpy.ndarray
28-
:param std: standard deviation value
29-
:type std: float
30-
:return: Radial basis kernel matrix
31-
:rtype: numpy.ndarray
32-
"""
33-
return np.exp(-cdist(X, Y, "sqeuclidean") / 2 * std**2)
16+
from frouros.utils.kernels import rbf_kernel
3417

3518

3619
class MMD(BaseDistanceBased):
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Test kernels module."""
2+
3+
import numpy as np # type: ignore
4+
import pytest # type: ignore
5+
6+
from frouros.utils.kernels import rbf_kernel
7+
8+
9+
# TODO: Create fixtures for the matrices and the expected kernel values
10+
11+
12+
@pytest.mark.parametrize(
13+
"X, Y, sigma, expected_kernel_value",
14+
[
15+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 0.5, np.array([[1.0]])),
16+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 1.0, np.array([[1.0]])),
17+
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 2.0, np.array([[1.0]])),
18+
(
19+
np.array([[1, 2, 3]]),
20+
np.array([[4, 5, 6]]),
21+
0.5,
22+
np.array([[3.53262857e-24]]),
23+
),
24+
(
25+
np.array([[1, 2, 3]]),
26+
np.array([[4, 5, 6]]),
27+
1.0,
28+
np.array([[1.37095909e-06]]),
29+
),
30+
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 2.0, np.array([[0.03421812]])),
31+
(
32+
np.array([[1, 2, 3], [4, 5, 6]]),
33+
np.array([[1, 2, 3], [4, 5, 6]]),
34+
0.5,
35+
np.array(
36+
[[1.00000000e00, 3.53262857e-24], [3.53262857e-24, 1.00000000e00]]
37+
),
38+
),
39+
(
40+
np.array([[1, 2, 3], [4, 5, 6]]),
41+
np.array([[1, 2, 3], [4, 5, 6]]),
42+
1.0,
43+
np.array(
44+
[[1.00000000e00, 1.37095909e-06], [1.37095909e-06, 1.00000000e00]]
45+
),
46+
),
47+
(
48+
np.array([[1, 2, 3], [4, 5, 6]]),
49+
np.array([[1, 2, 3], [4, 5, 6]]),
50+
2.0,
51+
np.array([[1.00000000e00, 0.03421812], [0.03421812, 1.00000000e00]]),
52+
),
53+
(
54+
np.array([[1, 2, 3], [4, 5, 6]]),
55+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
56+
0.5,
57+
np.array(
58+
[[2.23130160e-01, 1.20048180e-32], [5.17555501e-17, 2.23130160e-01]]
59+
),
60+
),
61+
(
62+
np.array([[1, 2, 3], [4, 5, 6]]),
63+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
64+
1.0,
65+
np.array(
66+
[[6.87289279e-01, 1.04674018e-08], [8.48182352e-05, 6.87289279e-01]]
67+
),
68+
),
69+
(
70+
np.array([[1, 2, 3], [4, 5, 6]]),
71+
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
72+
2.0,
73+
np.array([[0.91051036, 0.01011486], [0.09596709, 0.91051036]]),
74+
),
75+
],
76+
)
77+
def test_rbf_kernel(
78+
X: np.ndarray, # noqa: N803
79+
Y: np.ndarray,
80+
sigma: float,
81+
expected_kernel_value: np.ndarray,
82+
) -> None:
83+
"""Test rbf kernel.
84+
85+
:param X: X values
86+
:type X: numpy.ndarray
87+
:param Y: Y values
88+
:type Y: numpy.ndarray
89+
:param sigma: sigma value
90+
:type sigma: float
91+
:param expected_kernel_value: expected kernel value
92+
:type expected_kernel_value: numpy.ndarray
93+
"""
94+
assert np.all(
95+
np.isclose(
96+
rbf_kernel(
97+
X=X,
98+
Y=Y,
99+
sigma=sigma,
100+
),
101+
expected_kernel_value,
102+
),
103+
)

frouros/utils/kernels.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Kernels module."""
2+
3+
import numpy as np # type: ignore
4+
from scipy.spatial.distance import cdist # type: ignore
5+
6+
7+
def rbf_kernel(
8+
X: np.ndarray, Y: np.ndarray, sigma: float = 1.0 # noqa: N803
9+
) -> np.ndarray:
10+
"""Radial basis function kernel between X and Y matrices.
11+
12+
:param X: X matrix
13+
:type X: numpy.ndarray
14+
:param Y: Y matrix
15+
:type Y: numpy.ndarray
16+
:param sigma: sigma value (equivalent to gamma = 1 / (2 * sigma**2))
17+
:type sigma: float
18+
:return: Radial basis kernel matrix
19+
:rtype: numpy.ndarray
20+
"""
21+
return np.exp(-cdist(X, Y, "sqeuclidean") / (2 * sigma**2))

0 commit comments

Comments
 (0)