Skip to content

Commit ca251bc

Browse files
Fix rbf kernel calculation
1 parent 05fe195 commit ca251bc

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

frouros/utils/kernels.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Kernels module."""
2+
3+
import numpy as np # type: ignore
4+
from scipy.spatial.distance import cdist # type: ignore
5+
6+
7+
def rbf_kernel(
8+
X: np.ndarray, Y: np.ndarray, sigma: float = 1.0 # noqa: N803
9+
) -> np.ndarray:
10+
"""Radial basis function kernel between X and Y matrices.
11+
12+
:param X: X matrix
13+
:type X: numpy.ndarray
14+
:param Y: Y matrix
15+
:type Y: numpy.ndarray
16+
:param sigma: sigma value (equivalent to gamma = 1 / (2 * sigma**2))
17+
:type sigma: float
18+
:return: Radial basis kernel matrix
19+
:rtype: numpy.ndarray
20+
"""
21+
return np.exp(-cdist(X, Y, "sqeuclidean") / (2 * sigma**2))

0 commit comments

Comments
 (0)