Skip to content

Commit f6259ac

Browse files
committed
ENH: Implement squared exponential covariance function
Implement squared exponential covariance function.
1 parent 628aa8f commit f6259ac

File tree

2 files changed

+249
-2
lines changed

2 files changed

+249
-2
lines changed

src/nifreeze/model/gpr.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import numpy.typing as npt
3232
from scipy import optimize
3333
from scipy.optimize import Bounds
34+
from scipy.spatial.distance import cdist, pdist, squareform
3435
from sklearn.gaussian_process import GaussianProcessRegressor
3536
from sklearn.gaussian_process.kernels import (
3637
Hyperparameter,
@@ -41,6 +42,8 @@
4142

4243
BOUNDS_A: tuple[float, float] = (0.1, 2.35)
4344
"""The limits for the parameter *a* (angular distance in rad)."""
45+
BOUNDS_ELL: tuple[float, float] = (0.1, 2.35)
46+
r"""The limits for the parameter $\ell$ (length scale applied to differences of log b-values)."""
4447
BOUNDS_LAMBDA: tuple[float, float] = (1e-3, 1000)
4548
"""The limits for the parameter λ (signal scaling factor)."""
4649
THETA_EPSILON: float = 1e-5
@@ -472,6 +475,121 @@ def __repr__(self) -> str:
472475
return f"SphericalKriging (a={self.beta_a}, λ={self.beta_l})"
473476

474477

478+
class SquaredExponentialKriging(Kernel):
    r"""
    A scikit-learn kernel for DWI signals based on a squared exponential covariance.

    The kernel evaluates :math:`\lambda \, C_{b}`, where :math:`C_{b}` is the
    squared exponential covariance of the (log) diffusion gradient encoding
    weighting values, and :math:`\lambda` is a signal scaling factor.

    """

    def __init__(
        self,
        beta_ell: float = 1.38,
        beta_l: float = 0.5,
        ell_bounds: tuple[float, float] = BOUNDS_ELL,
        l_bounds: tuple[float, float] = BOUNDS_LAMBDA,
    ):
        r"""
        Initialize a squared exponential Kriging kernel.

        Parameters
        ----------
        beta_ell : :obj:`float`, optional
            The :math:`\ell` (length scale) hyperparameter.
        beta_l : :obj:`float`, optional
            The :math:`\lambda` hyperparameter.
        ell_bounds : :obj:`tuple`, optional
            Bounds for the :math:`\ell` hyperparameter.
        l_bounds : :obj:`tuple`, optional
            Bounds for the :math:`\lambda` hyperparameter.

        """
        self.beta_ell = beta_ell
        self.beta_l = beta_l
        # scikit-learn's ``get_params()``/``clone()`` read one attribute per
        # ``__init__`` argument, so attribute names must match argument names.
        self.ell_bounds = ell_bounds
        self.l_bounds = l_bounds

    @property
    def hyperparameter_ell(self) -> Hyperparameter:
        r"""Specification of the :math:`\ell` hyperparameter."""
        return Hyperparameter("beta_ell", "numeric", self.ell_bounds)

    @property
    def hyperparameter_l(self) -> Hyperparameter:
        r"""Specification of the :math:`\lambda` hyperparameter."""
        return Hyperparameter("beta_l", "numeric", self.l_bounds)

    def __call__(
        self, X: np.ndarray, Y: np.ndarray | None = None, eval_gradient: bool = False
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        r"""
        Return the kernel K(X, Y) and optionally its gradient.

        Parameters
        ----------
        X : :obj:`~numpy.ndarray`
            Gradient weighting values (X).
        Y : :obj:`~numpy.ndarray`, optional
            Gradient weighting values (Y). If ``None``, K(X, X) is evaluated.
        eval_gradient : :obj:`bool`, optional
            Determines whether the gradient with respect to the log of
            the kernel hyperparameters is also computed.

        Returns
        -------
        K : :obj:`~numpy.ndarray` of shape (n_samples_X, n_samples_Y)
            Kernel k(X, Y).

        K_gradient : :obj:`~numpy.ndarray` of shape (n_samples_X, n_samples_Y, 2),\
                optional
            The gradient of the kernel with respect to the log of the
            hyperparameters, in the order (:math:`\ell`, :math:`\lambda`).
            Only returned when ``eval_gradient`` is True.

        """
        # Squared differences between log-weightings; condensed when Y is None.
        sq_dists = compute_shell_distance(X, Y=Y)
        if Y is None:
            # Expand to a square matrix; the zero diagonal yields C_b = 1 there.
            sq_dists = squareform(sq_dists)

        C_b = squared_exponential_covariance(sq_dists, self.beta_ell)
        K = self.beta_l * C_b

        if not eval_gradient:
            return K

        # With K = lambda * exp(-d^2 / (2 ell^2)), where d^2 is ``sq_dists``:
        #   dK/d(ell)         = K * d^2 / ell^3  =>  dK/d(log ell)    = K * d^2 / ell^2
        #   dK/d(lambda)      = C_b              =>  dK/d(log lambda) = K
        # Hyperparameters are collected from the sorted attribute names, so
        # ``hyperparameter_ell`` precedes ``hyperparameter_l`` in theta.
        K_gradient = np.empty((*K.shape, 2))
        K_gradient[..., 0] = K * sq_dists / self.beta_ell**2
        K_gradient[..., 1] = K

        return K, K_gradient

    def diag(self, X: np.ndarray) -> np.ndarray:
        """Returns the diagonal of the kernel k(X, X).

        The result of this method is identical to np.diag(self(X)); however,
        it can be evaluated more efficiently since only the diagonal is
        evaluated.

        Parameters
        ----------
        X : :obj:`~numpy.ndarray` of shape (n_samples_X, n_features)
            Left argument of the returned kernel k(X, Y)

        Returns
        -------
        K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
            Diagonal of kernel k(X, X)
        """
        return self.beta_l * np.ones(X.shape[0])

    def is_stationary(self) -> bool:
        """Returns whether the kernel is stationary."""
        return True

    def __repr__(self) -> str:
        return f"SquaredExponentialKriging (ℓ={self.beta_ell}, λ={self.beta_l})"
591+
592+
475593
def exponential_covariance(theta: np.ndarray, a: float) -> np.ndarray:
476594
r"""
477595
Compute the exponential covariance for given distances and scale parameter.
@@ -593,3 +711,71 @@ def compute_pairwise_angles(
593711
thetas = np.arccos(np.abs(cosines)) if closest_polarity else np.arccos(cosines)
594712
thetas[np.abs(thetas) < THETA_EPSILON] = 0.0
595713
return thetas
714+
715+
716+
def squared_exponential_covariance(
    shell_distance: np.ndarray,
    ell: float,
) -> np.ndarray:
    r"""Evaluate the squared exponential covariance for the given distances.

    Implements :math:`C_{b}`, following Eq. (15) in [Andersson15]_:

    .. math::

        C_{b}(b, b'; \ell) = \exp\left( - \frac{(\log b - \log b')^2}{2 \ell^2} \right)

    The squared exponential covariance function is sometimes called radial basis
    function (RBF) or Gaussian kernel.

    Parameters
    ----------
    shell_distance : :obj:`~numpy.ndarray`
        Distances between diffusion gradient encoding weighting values. The
        values are used as the whole numerator of the exponent (they are not
        squared here), so they must already hold
        :math:`(\log b - \log b')^2`.
    ell : float
        Length scale :math:`\ell` of the covariance function.

    Returns
    -------
    :obj:`~numpy.ndarray`
        Squared exponential covariance values, with the same shape as
        ``shell_distance``.
    """

    scaled_distance = shell_distance / (ell**2)
    return np.exp(-0.5 * scaled_distance)
746+
747+
748+
def compute_shell_distance(X, Y=None):
749+
r"""Compute pairwise angles across diffusion gradient encoding weighting
750+
values.
751+
752+
Following Eq. (15) in [Andersson15]_, computes the distance between the log
753+
values of the diffusion gradient encoding weighting values:
754+
755+
.. math::
756+
757+
\log b - \log b'
758+
759+
Parameters
760+
----------
761+
X : :obj:`~numpy.ndarray` of shape (n_samples_X, n_features)
762+
Input data.
763+
Y : :obj:`~numpy.ndarray` of shape (n_samples_Y, n_features), optional
764+
Input data. If ``None``, the output will be the pairwise
765+
similarities between all samples in ``X``.
766+
767+
Returns
768+
-------
769+
:obj:`~numpy.ndarray`
770+
Pairwise distances of diffusion gradient encoding weighting values.
771+
"""
772+
773+
# ToDo
774+
# scikit-learn RBF call includes $\ell$ here; fine, but then I do not get
775+
# the derivative computation the way they compute it
776+
if Y is None:
777+
dists = pdist(np.log(X), metric="sqeuclidean")
778+
else:
779+
dists = cdist(np.log(X), np.log(Y), metric="sqeuclidean")
780+
781+
return dists

test/test_gpr.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,8 @@ def test_compute_pairwise_angles(bvecs1, bvecs2, closest_polarity, expected):
267267
np.testing.assert_array_almost_equal(obtained, expected, decimal=2)
268268

269269

270-
@pytest.mark.parametrize("covariance", ["Spherical", "Exponential"])
271-
def test_kernel(repodata, covariance):
270+
@pytest.mark.parametrize("covariance", ["Spherical", "Exponential", "SquaredExponential"])
271+
def test_kernel_single_shell(repodata, covariance):
272272
"""Check kernel construction."""
273273

274274
bvals, bvecs = read_bvals_bvecs(
@@ -292,3 +292,64 @@ def test_kernel(repodata, covariance):
292292

293293
K_predict = kernel(bvecs, bvecs[10:14, ...])
294294
assert K_predict.shape == (K.shape[0], 4)
295+
296+
297+
@pytest.mark.parametrize(
    ("bvals1", "bvals2", "expected"),
    [
        # Single shell, Y=None: condensed output with n * (n - 1) / 2 zeros.
        (
            np.array([[1000.0], [1000.0], [1000.0], [1000.0]]),
            None,
            np.zeros(6),
        ),
        # Two shells, Y=None: condensed pairs are (0, 1), (0, 2), (1, 2).
        (
            np.array([[1000.0], [1000.0], [2000.0]]),
            None,
            np.log(2.0) ** 2 * np.array([0.0, 1.0, 1.0]),
        ),
        # X against Y: full (n_samples_X, n_samples_Y) matrix; a factor of 4
        # in the weighting doubles the log difference, i.e. 4x the squared one.
        (
            np.array([[1000.0], [2000.0]]),
            np.array([[1000.0], [2000.0], [4000.0]]),
            np.log(2.0) ** 2 * np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0]]),
        ),
    ],
)
def test_compute_shell_distance(bvals1, bvals2, expected):
    """Check the squared distances between log weighting values."""
    obtained = gpr.compute_shell_distance(bvals1, bvals2)

    if bvals2 is not None:
        assert (bvals1.shape[0], bvals2.shape[0]) == obtained.shape
    assert obtained.shape == expected.shape
    np.testing.assert_array_almost_equal(obtained, expected, decimal=2)
337+
338+
339+
@pytest.mark.parametrize("covariance", ["SquaredExponential"])
def test_kernel_multi_shell(repodata, covariance):
    """Check kernel construction on multi-shell weighting values."""

    bvals, _ = read_bvals_bvecs(
        str(repodata / "ds000114_multishell.bval"),
        str(repodata / "ds000114_multishell.bvec"),
    )

    # Boolean masking always yields a 1D array; the kernel computes distances
    # with scipy's pdist/cdist, which require (n_samples, n_features) input.
    bvals = bvals[bvals > 10][:, np.newaxis]

    KernelType = getattr(gpr, f"{covariance}Kriging")
    kernel = KernelType()

    K = kernel(bvals)

    assert K.shape == (bvals.shape[0],) * 2
    # A covariance matrix is symmetric and its diagonal must match ``diag``.
    np.testing.assert_allclose(K, K.T)
    np.testing.assert_allclose(np.diag(K), kernel.diag(bvals))

0 commit comments

Comments
 (0)