import math

import torch
- from linear_operator.operators import KeOpsLinearOperator
+ from linear_operator.operators import KernelLinearOperator

- from ... import settings
+ from ..matern_kernel import MaternKernel as GMaternKernel
from .keops_kernel import KeOpsKernel

try:
    from pykeops.torch import LazyTensor as KEOLazyTensor

+     def _covar_func(x1, x2, nu=2.5, **params):
+         x1_ = KEOLazyTensor(x1[..., :, None, :])
+         x2_ = KEOLazyTensor(x2[..., None, :, :])
+
+         distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
+         exp_component = (-math.sqrt(nu * 2) * distance).exp()
+
+         if nu == 0.5:
+             constant_component = 1
+         elif nu == 1.5:
+             constant_component = (math.sqrt(3) * distance) + 1
+         elif nu == 2.5:
+             constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * (distance ** 2))
+
+         return constant_component * exp_component
+
    class MaternKernel(KeOpsKernel):
        """
        Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies.

-         This class can be used as a drop in replacement for gpytorch.kernels.MaternKernel in most cases, and supports
-         the same arguments. There are currently a few limitations, for example a lack of batch mode support. However,
-         most other features like ARD will work.
+         This class can be used as a drop-in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases,
+         and supports the same arguments.
+
+         :param nu: (Default: 2.5) The smoothness parameter.
+         :type nu: float (0.5, 1.5, or 2.5)
+         :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
+             input dimension. It should be `d` if x1 is a `... x n x d` matrix.
+         :type ard_num_dims: int, optional
+         :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
+             batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
+         :type batch_shape: torch.Size, optional
+         :param active_dims: (Default: `None`) Set this if you want to
+             compute the covariance of only a few input dimensions. The ints
+             correspond to the indices of the dimensions.
+         :type active_dims: Tuple(int)
+         :param lengthscale_prior: (Default: `None`)
+             Set this if you want to apply a prior to the lengthscale parameter.
+         :type lengthscale_prior: ~gpytorch.priors.Prior, optional
+         :param lengthscale_constraint: (Default: `Positive`) Set this if you want
+             to apply a constraint to the lengthscale parameter.
+         :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
+         :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
+         :type eps: float, optional
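+
+         A minimal usage sketch (the tensor `x` below is a placeholder, and the
+         KeOps backend is assumed to be available):
+
+         Example:
+             >>> import torch
+             >>> from gpytorch.kernels import ScaleKernel
+             >>> from gpytorch.kernels.keops import MaternKernel
+             >>> x = torch.randn(10, 5)
+             >>> # ARD: a separate lengthscale for each of the 5 input dimensions
+             >>> covar_module = ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=5))
+             >>> covar = covar_module(x)  # lazily evaluated 10 x 10 kernel matrix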
        """

        has_lengthscale = True
@@ -27,8 +63,12 @@ def __init__(self, nu=2.5, **kwargs):
            super(MaternKernel, self).__init__(**kwargs)
            self.nu = nu

-         def _nonkeops_covar_func(self, x1, x2, diag=False):
-             distance = self.covar_dist(x1, x2, diag=diag)
+         def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
+             mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
+             x1_ = (x1 - mean) / self.lengthscale
+             x2_ = (x2 - mean) / self.lengthscale
+
+             distance = self.covar_dist(x1_, x2_, diag=diag, **kwargs)
            exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)

            if self.nu == 0.5:
@@ -39,63 +79,14 @@ def _nonkeops_covar_func(self, x1, x2, diag=False):
                constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance ** 2)
            return constant_component * exp_component

-         def covar_func(self, x1, x2, diag=False):
-             # We only should use KeOps on big kernel matrices
-             # If we would otherwise be performing Cholesky inference, (or when just computing a kernel matrix diag)
-             # then don't apply KeOps
-             # enable gradients to ensure that test time caches on small predictions are still
-             # backprop-able
-             with torch.autograd.enable_grad():
-                 if (
-                     diag
-                     or x1.size(-2) < settings.max_cholesky_size.value()
-                     or x2.size(-2) < settings.max_cholesky_size.value()
-                 ):
-                     return self._nonkeops_covar_func(x1, x2, diag=diag)
-                 # TODO: x1 / x2 size checks are a work around for a very minor bug in KeOps.
-                 # This bug is fixed on KeOps master, and we'll remove that part of the check
-                 # when they cut a new release.
-                 elif x1.size(-2) == 1 or x2.size(-2) == 1:
-                     return self._nonkeops_covar_func(x1, x2, diag=diag)
-                 else:
-                     # We only should use KeOps on big kernel matrices
-                     # If we would otherwise be performing Cholesky inference, then don't apply KeOps
-                     if (
-                         x1.size(-2) < settings.max_cholesky_size.value()
-                         or x2.size(-2) < settings.max_cholesky_size.value()
-                     ):
-                         x1_ = x1[..., :, None, :]
-                         x2_ = x2[..., None, :, :]
-                     else:
-                         x1_ = KEOLazyTensor(x1[..., :, None, :])
-                         x2_ = KEOLazyTensor(x2[..., None, :, :])
-
-                     distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
-                     exp_component = (-math.sqrt(self.nu * 2) * distance).exp()
-
-                     if self.nu == 0.5:
-                         constant_component = 1
-                     elif self.nu == 1.5:
-                         constant_component = (math.sqrt(3) * distance) + 1
-                     elif self.nu == 2.5:
-                         constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * distance ** 2)
-
-                     return constant_component * exp_component
-
-         def forward(self, x1, x2, diag=False, **params):
+         def _keops_forward(self, x1, x2, **kwargs):
            mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
-
-             x1_ = (x1 - mean).div(self.lengthscale)
-             x2_ = (x2 - mean).div(self.lengthscale)
-
-             if diag:
-                 return self.covar_func(x1_, x2_, diag=True)
-
-             covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
-             return KeOpsLinearOperator(x1_, x2_, covar_func)
+             x1_ = (x1 - mean) / self.lengthscale
+             x2_ = (x2 - mean) / self.lengthscale
+             # return a KernelLinearOperator instance only when calculating the whole covariance matrix
+             return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)

except ImportError:

-     class MaternKernel(KeOpsKernel):
-         def __init__(self, *args, **kwargs):
-             super().__init__()
+     class MaternKernel(GMaternKernel):
+         pass
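With this change, `_keops_forward` hands the symbolic `_covar_func` to `KernelLinearOperator`, so the covariance is represented lazily and KeOps evaluates matrix-vector products on demand instead of materializing a dense kernel matrix. A minimal usage sketch, assuming pykeops is installed and using placeholder tensors (the base `KeOpsKernel` is expected to route small inputs to `_nonkeops_forward` and large ones to `_keops_forward`):

import torch
from gpytorch.kernels.keops import MaternKernel

x1 = torch.randn(2000, 3)  # placeholder inputs
x2 = torch.randn(1500, 3)
kernel = MaternKernel(nu=1.5, ard_num_dims=3)

K = kernel(x1, x2)         # lazily evaluated 2000 x 1500 covariance
v = torch.randn(1500, 1)
mvm = K @ v                # matrix-vector product, driven by KeOps for large inputs

If pykeops cannot be imported, the `except ImportError` branch now aliases `MaternKernel` to the standard `gpytorch.kernels.MaternKernel` (`GMaternKernel`), so the import above still works, just without KeOps acceleration.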