Skip to content

Commit deb90c2

Browse files
m-julian and gpleiss authored
keops periodic and keops kernels unit tests (#2296)
* working version * change name of import * add periodic kernel * copy original periodic kernel code * tests for keops periodic kernel * addtests * remove copied file * fixed NaN * add more tests * update docstring * add nonkeops tests * add cuda check * formatting * use arithmetic operators * add tests * use cuda tensors * subcclass from periodic kernel * docstring update * base keops class for tests * run keops tests on cpu * formatting * use KernelLinearOperator * add comment * gradient and ard tests * another gradient test * diag and refactor * Update test cases, adapt to new KernelLinearOperator style * Update to latest version of LinearOperator, add keops tests to CI * Add behavioral test for KeOps regression * Include KeOps kernels in the docs * Refactor keops implementation, add more testing --------- Co-authored-by: Geoff Pleiss <[email protected]>
1 parent 413898a commit deb90c2

File tree

16 files changed

+560
-211
lines changed

16 files changed

+560
-211
lines changed

.conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ requirements:
1818
run:
1919
- pytorch>=1.11
2020
- scikit-learn
21-
- linear_operator>=0.4.0
21+
- linear_operator>=0.5.0
2222

2323
test:
2424
imports:

.github/workflows/run_test_suite.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
pip install -e .
5656
if [[ ${{ matrix.extras }} == "with-extras" ]]; then
5757
pip install "pyro-ppl>=1.8";
58+
pip install pykeops;
5859
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
5960
fi
6061
- name: Run unit tests
@@ -75,7 +76,8 @@ jobs:
7576
pip install pytest nbval jupyter tqdm matplotlib torchvision scipy
7677
pip install -e .
7778
pip install "pyro-ppl>=1.8";
78-
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
79+
pip install pykeops;
80+
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
7981
- name: Run example notebooks
8082
run: |
8183
grep -l smoke_test examples/**/*.ipynb | xargs grep -L 'smoke_test = False' | CI=true xargs pytest --nbval-lax --current-env

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
setuptools_scm<=7.1.0
22
ipython<=8.6.0
33
ipykernel<=6.17.1
4-
linear_operator>=0.4.0
4+
linear_operator>=0.5.0
55
m2r2<=0.3.3.post2
66
nbclient<=0.7.3
77
nbformat<=5.8.0

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ GPyTorch's documentation
2929
models
3030
likelihoods
3131
kernels
32+
keops_kernels
3233
means
3334
marginal_log_likelihoods
3435
metrics

docs/source/keops_kernels.rst

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
.. role:: hidden
2+
:class: hidden-section
3+
4+
gpytorch.kernels.keops
5+
===================================
6+
7+
.. automodule:: gpytorch.kernels.keops
8+
.. currentmodule:: gpytorch.kernels.keops
9+
10+
11+
These kernels are compatible with the GPyTorch KeOps integration.
12+
For more information, see the `KeOps tutorial`_.
13+
14+
.. note::
15+
Only some standard kernels have KeOps implementations.
16+
If there is a kernel you want that's missing, consider submitting a pull request!
17+
18+
19+
.. _KeOps Tutorial:
20+
examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.html
21+
22+
23+
:hidden:`RBFKernel`
24+
~~~~~~~~~~~~~~~~~~~~~~
25+
26+
.. autoclass:: RBFKernel
27+
:members:
28+
29+
30+
:hidden:`MaternKernel`
31+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
32+
33+
.. autoclass:: MaternKernel
34+
:members:
35+
36+
37+
:hidden:`PeriodicKernel`
38+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
39+
40+
.. autoclass:: PeriodicKernel
41+
:members:

gpytorch/kernels/keops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .matern_kernel import MaternKernel
2+
from .periodic_kernel import PeriodicKernel
23
from .rbf_kernel import RBFKernel
34

4-
__all__ = ["MaternKernel", "RBFKernel"]
5+
__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"]
Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,41 @@
11
from abc import abstractmethod
2+
from typing import Any
23

34
import torch
5+
from torch import Tensor
46

7+
from ... import settings
58
from ..kernel import Kernel
69

710
try:
8-
from pykeops.torch import LazyTensor as KEOLazyTensor
11+
import pykeops # noqa F401
912

1013
class KeOpsKernel(Kernel):
1114
@abstractmethod
12-
def covar_func(self, x1: torch.Tensor, x2: torch.Tensor) -> KEOLazyTensor:
13-
raise NotImplementedError("KeOpsKernels must define a covar_func method")
15+
def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
16+
r"""
17+
Computes the covariance matrix (or diagonal) without using KeOps.
18+
This function must implement both the diag=True and diag=False options.
19+
"""
20+
raise NotImplementedError
1421

15-
def __call__(self, *args, **kwargs):
22+
@abstractmethod
23+
def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any):
24+
r"""
25+
Computes the covariance matrix with KeOps.
26+
This function only implements the diag=False option, and no diag keyword should be passed in.
27+
"""
28+
raise NotImplementedError
29+
30+
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
31+
if diag:
32+
return self._nonkeops_forward(x1, x2, diag=True, **kwargs)
33+
elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value():
34+
return self._nonkeops_forward(x1, x2, diag=False, **kwargs)
35+
else:
36+
return self._keops_forward(x1, x2, **kwargs)
37+
38+
def __call__(self, *args: Any, **kwargs: Any):
1639
# Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543
1740
args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args]
1841
kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()}
@@ -21,5 +44,5 @@ def __call__(self, *args, **kwargs):
2144
except ImportError:
2245

2346
class KeOpsKernel(Kernel):
24-
def __init__(self, *args, **kwargs):
47+
def __init__(self, *args: Any, **kwargs: Any):
2548
raise RuntimeError("You must have KeOps installed to use a KeOpsKernel")

gpytorch/kernels/keops/matern_kernel.py

Lines changed: 54 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,57 @@
22
import math
33

44
import torch
5-
from linear_operator.operators import KeOpsLinearOperator
5+
from linear_operator.operators import KernelLinearOperator
66

7-
from ... import settings
7+
from ..matern_kernel import MaternKernel as GMaternKernel
88
from .keops_kernel import KeOpsKernel
99

1010
try:
1111
from pykeops.torch import LazyTensor as KEOLazyTensor
1212

13+
def _covar_func(x1, x2, nu=2.5, **params):
14+
x1_ = KEOLazyTensor(x1[..., :, None, :])
15+
x2_ = KEOLazyTensor(x2[..., None, :, :])
16+
17+
distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
18+
exp_component = (-math.sqrt(nu * 2) * distance).exp()
19+
20+
if nu == 0.5:
21+
constant_component = 1
22+
elif nu == 1.5:
23+
constant_component = (math.sqrt(3) * distance) + 1
24+
elif nu == 2.5:
25+
constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * (distance**2))
26+
27+
return constant_component * exp_component
28+
1329
class MaternKernel(KeOpsKernel):
1430
"""
1531
Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies.
1632
17-
This class can be used as a drop in replacement for gpytorch.kernels.MaternKernel in most cases, and supports
18-
the same arguments. There are currently a few limitations, for example a lack of batch mode support. However,
19-
most other features like ARD will work.
33+
This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases,
34+
and supports the same arguments.
35+
36+
:param nu: (Default: 2.5) The smoothness parameter.
37+
:type nu: float (0.5, 1.5, or 2.5)
38+
:param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
39+
input dimension. It should be `d` if x1 is a `... x n x d` matrix.
40+
:type ard_num_dims: int, optional
41+
:param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
42+
batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
43+
:type batch_shape: torch.Size, optional
44+
:param active_dims: (Default: `None`) Set this if you want to
45+
compute the covariance of only a few input dimensions. The ints
46+
correspond to the indices of the dimensions.
47+
:type active_dims: Tuple(int)
48+
:param lengthscale_prior: (Default: `None`)
49+
Set this if you want to apply a prior to the lengthscale parameter.
50+
:type lengthscale_prior: ~gpytorch.priors.Prior, optional
51+
:param lengthscale_constraint: (Default: `Positive`) Set this if you want
52+
to apply a constraint to the lengthscale parameter.
53+
:type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
54+
:param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
55+
:type eps: float, optional
2056
"""
2157

2258
has_lengthscale = True
@@ -27,8 +63,12 @@ def __init__(self, nu=2.5, **kwargs):
2763
super(MaternKernel, self).__init__(**kwargs)
2864
self.nu = nu
2965

30-
def _nonkeops_covar_func(self, x1, x2, diag=False):
31-
distance = self.covar_dist(x1, x2, diag=diag)
66+
def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
67+
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
68+
x1_ = (x1 - mean) / self.lengthscale
69+
x2_ = (x2 - mean) / self.lengthscale
70+
71+
distance = self.covar_dist(x1_, x2_, diag=diag, **kwargs)
3272
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)
3373

3474
if self.nu == 0.5:
@@ -39,63 +79,14 @@ def _nonkeops_covar_func(self, x1, x2, diag=False):
3979
constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2)
4080
return constant_component * exp_component
4181

42-
def covar_func(self, x1, x2, diag=False):
43-
# We only should use KeOps on big kernel matrices
44-
# If we would otherwise be performing Cholesky inference, (or when just computing a kernel matrix diag)
45-
# then don't apply KeOps
46-
# enable gradients to ensure that test time caches on small predictions are still
47-
# backprop-able
48-
with torch.autograd.enable_grad():
49-
if (
50-
diag
51-
or x1.size(-2) < settings.max_cholesky_size.value()
52-
or x2.size(-2) < settings.max_cholesky_size.value()
53-
):
54-
return self._nonkeops_covar_func(x1, x2, diag=diag)
55-
# TODO: x1 / x2 size checks are a work around for a very minor bug in KeOps.
56-
# This bug is fixed on KeOps master, and we'll remove that part of the check
57-
# when they cut a new release.
58-
elif x1.size(-2) == 1 or x2.size(-2) == 1:
59-
return self._nonkeops_covar_func(x1, x2, diag=diag)
60-
else:
61-
# We only should use KeOps on big kernel matrices
62-
# If we would otherwise be performing Cholesky inference, then don't apply KeOps
63-
if (
64-
x1.size(-2) < settings.max_cholesky_size.value()
65-
or x2.size(-2) < settings.max_cholesky_size.value()
66-
):
67-
x1_ = x1[..., :, None, :]
68-
x2_ = x2[..., None, :, :]
69-
else:
70-
x1_ = KEOLazyTensor(x1[..., :, None, :])
71-
x2_ = KEOLazyTensor(x2[..., None, :, :])
72-
73-
distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
74-
exp_component = (-math.sqrt(self.nu * 2) * distance).exp()
75-
76-
if self.nu == 0.5:
77-
constant_component = 1
78-
elif self.nu == 1.5:
79-
constant_component = (math.sqrt(3) * distance) + 1
80-
elif self.nu == 2.5:
81-
constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * distance**2)
82-
83-
return constant_component * exp_component
84-
85-
def forward(self, x1, x2, diag=False, **params):
82+
def _keops_forward(self, x1, x2, **kwargs):
8683
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
87-
88-
x1_ = (x1 - mean).div(self.lengthscale)
89-
x2_ = (x2 - mean).div(self.lengthscale)
90-
91-
if diag:
92-
return self.covar_func(x1_, x2_, diag=True)
93-
94-
covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
95-
return KeOpsLinearOperator(x1_, x2_, covar_func)
84+
x1_ = (x1 - mean) / self.lengthscale
85+
x2_ = (x2 - mean) / self.lengthscale
86+
# return KernelLinearOperator inst only when calculating the whole covariance matrix
87+
return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
9688

9789
except ImportError:
9890

99-
class MaternKernel(KeOpsKernel):
100-
def __init__(self, *args, **kwargs):
101-
super().__init__()
91+
class MaternKernel(GMaternKernel):
92+
pass
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python3
2+
3+
import math
4+
5+
from linear_operator.operators import KernelLinearOperator
6+
7+
from ..periodic_kernel import PeriodicKernel as GPeriodicKernel
8+
from .keops_kernel import KeOpsKernel
9+
10+
# from ...kernels import PeriodicKernel gives a cyclic import
11+
12+
try:
13+
from pykeops.torch import LazyTensor as KEOLazyTensor
14+
15+
def _covar_func(x1, x2, lengthscale, **kwargs):
16+
# symbolic array of shape ..., ndatax1_ x 1 x ndim
17+
x1_ = KEOLazyTensor(x1[..., :, None, :])
18+
# symbolic array of shape ..., 1 x ndatax2_ x ndim
19+
x2_ = KEOLazyTensor(x2[..., None, :, :])
20+
lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim
21+
# do not use .power(2.0) as it gives NaN values on cuda
22+
# seems related to https://github.com/getkeops/keops/issues/112
23+
K = ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp()
24+
return K
25+
26+
# subclass from original periodic kernel to reduce code duplication
27+
class PeriodicKernel(KeOpsKernel, GPeriodicKernel):
28+
"""
29+
Implements the Periodic Kernel using KeOps as a driver for kernel matrix multiplies.
30+
31+
This class can be used as a drop in replacement for :class:`gpytorch.kernels.PeriodicKernel` in most cases,
32+
and supports the same arguments.
33+
34+
:param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
35+
input dimension. It should be `d` if x1 is a `... x n x d` matrix.
36+
:type ard_num_dims: int, optional
37+
:param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
38+
batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
39+
:type batch_shape: torch.Size, optional
40+
:param active_dims: (Default: `None`) Set this if you want to
41+
compute the covariance of only a few input dimensions. The ints
42+
correspond to the indices of the dimensions.
43+
:type active_dims: Tuple(int)
44+
:param period_length_prior: (Default: `None`)
45+
Set this if you want to apply a prior to the period length parameter.
46+
:type period_length_prior: ~gpytorch.priors.Prior, optional
47+
:param period_length_constraint: (Default: `Positive`) Set this if you want
48+
to apply a constraint to the period length parameter.
49+
:type period_length_constraint: ~gpytorch.constraints.Interval, optional
50+
:param lengthscale_prior: (Default: `None`)
51+
Set this if you want to apply a prior to the lengthscale parameter.
52+
:type lengthscale_prior: ~gpytorch.priors.Prior, optional
53+
:param lengthscale_constraint: (Default: `Positive`) Set this if you want
54+
to apply a constraint to the lengthscale parameter.
55+
:type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
56+
:param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
57+
:type eps: float, optional
58+
59+
:var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the
60+
ard_num_dims and batch_shape arguments.
61+
"""
62+
63+
has_lengthscale = True
64+
65+
# code from the already-implemented Periodic Kernel
66+
def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
67+
x1_ = x1.div(self.period_length / math.pi)
68+
x2_ = x2.div(self.period_length / math.pi)
69+
70+
# We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
71+
diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True)
72+
73+
if diag:
74+
lengthscale = self.lengthscale[..., 0, :, None]
75+
else:
76+
lengthscale = self.lengthscale[..., 0, :, None, None]
77+
78+
exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)
79+
exp_term = exp_term.sum(dim=(-2 if diag else -3))
80+
81+
return exp_term.exp()
82+
83+
def _keops_forward(self, x1, x2, **kwargs):
84+
x1_ = x1.div(self.period_length / math.pi)
85+
x2_ = x2.div(self.period_length / math.pi)
86+
# return KernelLinearOperator inst only when calculating the whole covariance matrix
87+
# pass any parameters which are used inside _covar_func as *args to get gradients computed for them
88+
return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
89+
90+
except ImportError:
91+
92+
class PeriodicKernel(GPeriodicKernel):
93+
pass

0 commit comments

Comments
 (0)