
Commit 577693e

Remove evaluate kernel
1 parent 305c1ea commit 577693e

23 files changed: 249 additions, 81 deletions

gpytorch/kernels/additive_structure_kernel.py

Lines changed: 11 additions & 7 deletions
@@ -34,13 +34,6 @@ class AdditiveStructureKernel(Kernel):
             Passed down to the `base_kernel`.
     """
 
-    @property
-    def is_stationary(self) -> bool:
-        """
-        Kernel is stationary if the base kernel is stationary.
-        """
-        return self.base_kernel.is_stationary
-
     def __init__(
         self,
         base_kernel: Kernel,
@@ -51,6 +44,17 @@ def __init__(
         self.base_kernel = base_kernel
         self.num_dims = num_dims
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return self.base_kernel._lazily_evaluate
+
+    @property
+    def is_stationary(self) -> bool:
+        """
+        Kernel is stationary if the base kernel is stationary.
+        """
+        return self.base_kernel.is_stationary
+
     def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
         if last_dim_is_batch:
             raise RuntimeError("AdditiveStructureKernel does not accept the last_dim_is_batch argument.")
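A hypothetical check (not part of the diff) that the wrapper now simply defers both flags to its base kernel; the RBF base kernel here is an arbitrary stand-in:

    import gpytorch

    base = gpytorch.kernels.RBFKernel()
    kernel = gpytorch.kernels.AdditiveStructureKernel(base, num_dims=3)

    print(kernel.is_stationary)     # True, mirrored from the RBF base kernel
    print(kernel._lazily_evaluate)  # True, mirrored from the RBF base kernel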

gpytorch/kernels/cosine_kernel.py

Lines changed: 4 additions & 2 deletions
@@ -56,8 +56,6 @@ class CosineKernel(Kernel):
         >>> covar = covar_module(x)  # Output: LazyVariable of size (2 x 10 x 10)
     """
 
-    is_stationary = True
-
     def __init__(
         self,
         period_length_prior: Optional[Prior] = None,
@@ -85,6 +83,10 @@ def __init__(
 
         self.register_constraint("raw_period_length", period_length_constraint)
 
+    @property
+    def is_stationary(self):
+        return True
+
     @property
     def period_length(self):
         return self.raw_period_length_constraint.transform(self.raw_period_length)

gpytorch/kernels/cylindrical_kernel.py

Lines changed: 1 addition & 3 deletions
@@ -4,7 +4,6 @@
 
 import torch
 
-from .. import settings
 from ..constraints import Interval, Positive
 from ..priors import Prior
 from .kernel import Kernel
@@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
             else:
                 angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))
 
-        with settings.lazily_evaluate_kernels(False):
-            radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
+        radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
         return radial_kernel.mul(angular_kernel)
 
     def kuma(self, x: torch.Tensor) -> torch.Tensor:
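For context (not part of the diff): calling a kernel object may now return a lazily evaluated operator, while calling its forward method always computes the covariance immediately, which is what the rewritten line relies on. A minimal sketch, with an RBFKernel standing in for the radial base kernel:

    import torch
    import gpytorch

    radial_base_kernel = gpytorch.kernels.RBFKernel()
    r1 = torch.rand(10, 1)
    r2 = torch.rand(10, 1)

    lazy_covar = radial_base_kernel(r1, r2)           # may defer evaluation
    dense_covar = radial_base_kernel.forward(r1, r2)  # always a plain Tensor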

gpytorch/kernels/grid_interpolation_kernel.py

Lines changed: 7 additions & 0 deletions
@@ -121,6 +121,13 @@ def __init__(
         )
         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point kernel
+        # matrix always needs to be evaluated, regardless of the size of x1 and x2), and the
+        # InterpolatedLinearOperator structure is needed for fast predictions.
+        return False
+
     @property
     def _tight_grid_bounds(self):
         grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))

gpytorch/kernels/grid_kernel.py

Lines changed: 9 additions & 2 deletions
@@ -44,8 +44,6 @@ class GridKernel(Kernel):
         http://www.cs.cmu.edu/~andrewgw/manet.pdf
     """
 
-    is_stationary = True
-
     def __init__(
         self,
         base_kernel: Kernel,
@@ -66,6 +64,15 @@ def __init__(
         if not self.interpolation_mode:
             self.register_buffer("full_grid", create_data_from_grid(grid))
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # Toeplitz structure is very efficient; no need to lazily evaluate
+        return False
+
+    @property
+    def is_stationary(self) -> bool:
+        return True
+
     def _clear_cache(self):
         if hasattr(self, "_cached_kernel_mat"):
             del self._cached_kernel_mat

gpytorch/kernels/index_kernel.py

Lines changed: 6 additions & 0 deletions
@@ -76,6 +76,12 @@ def __init__(
 
         self.register_constraint("raw_var", var_constraint)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # IndexKernel does not need lazy evaluation, since the complete B B^T + D_v matrix is always
+        # computed regardless of x1 and x2
+        return False
+
     @property
     def var(self):
         return self.raw_var_constraint.transform(self.raw_var)
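For context (not part of the diff): IndexKernel's task covariance is the full matrix B B^T + diag(v), so selecting entries for particular task indices never avoids forming it. A hypothetical check, assuming four tasks:

    import torch
    import gpytorch

    k = gpytorch.kernels.IndexKernel(num_tasks=4, rank=1)
    i = torch.tensor([[0], [1], [3]])

    print(k.covar_matrix.shape)      # full 4 x 4 task covariance B B^T + diag(v)
    print(k(i, i).to_dense().shape)  # (3, 3): a slice of that same matrix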

gpytorch/kernels/inducing_point_kernel.py

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,12 @@ def _clear_cache(self):
         if hasattr(self, "_cached_kernel_inv_root"):
             del self._cached_kernel_inv_root
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # InducingPointKernels should not lazily evaluate; to use the Woodbury formula,
+        # we want the Kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
+        return False
+
     @property
     def _inducing_mat(self):
         if not self.training and hasattr(self, "_cached_kernel_mat"):
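For reference (standard linear algebra, not part of this diff), the Woodbury identity the comment appeals to:

    (A + UCV)^{-1} = A^{-1} - A^{-1} U (C^{-1} + V A^{-1} U)^{-1} V A^{-1}

With A = sigma^2 I, U = K_{XZ}, C = K_{ZZ}^{-1}, V = K_{ZX}, solves against the inducing-point covariance reduce to an m x m system (m = number of inducing points). Keeping the low-rank operator type in the kernel's output is what exposes this structure downstream; a generic KernelLinearOperator would hide it.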

gpytorch/kernels/keops/matern_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 import math
 
 import torch
-from linear_operator.operators import KeOpsLinearOperator
+from linear_operator.operators import KernelLinearOperator
 
 from ... import settings
 from .keops_kernel import KeOpsKernel
@@ -92,7 +92,7 @@ def forward(self, x1, x2, diag=False, **params):
                return self.covar_func(x1_, x2_, diag=True)
 
            covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
-            return KeOpsLinearOperator(x1_, x2_, covar_func)
+            return KernelLinearOperator(x1_, x2_, covar_func=covar_func)
 
 except ImportError:

gpytorch/kernels/keops/rbf_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 import torch
-from linear_operator.operators import KeOpsLinearOperator
+from linear_operator.operators import KernelLinearOperator
 
 from ... import settings
 from ..rbf_kernel import postprocess_rbf
@@ -54,7 +54,7 @@ def forward(self, x1, x2, diag=False, **params):
            if diag:
                return covar_func(x1_, x2_, diag=True)
 
-            return KeOpsLinearOperator(x1_, x2_, covar_func)
+            return KernelLinearOperator(x1_, x2_, covar_func=covar_func)
 
 except ImportError:

gpytorch/kernels/kernel.py

Lines changed: 123 additions & 13 deletions
@@ -4,12 +4,13 @@
 
 import warnings
 from abc import abstractmethod
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import LinearOperator, ZeroLinearOperator
+from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator
 from torch import Tensor
 from torch.nn import ModuleList
 
@@ -75,6 +76,45 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
         return self._postprocess(res) if postprocess else res
 
 
+class _autograd_kernel_hack(object):
+    """
+    Helper class.
+
+    When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients.
+    (Any Tensor that `covar_func` closes over will not backpropagate gradients.)
+    Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters.
+
+    This context manager temporarily replaces a kernel's (and its submodules') parameter assignments with an
+    external set of references to these parameters.
+    The external set of references will be passed in by KernelLinearOperator.
+
+    This way, when calling self.forward, no parameter references are closed over, and so all parameters
+    will receive the appropriate gradients.
+    """
+
+    def __init__(self, kernel: Kernel, params: Iterable[torch.nn.Parameter], param_names: Iterable[str]):
+        self.temp_module_param_dicts = defaultdict(OrderedDict)
+        for name, param in zip(param_names, params):
+            split_name = name.split(".")
+            module = kernel
+            while len(split_name) > 1:
+                module_name, *remaining_names = split_name
+                module = getattr(module, module_name)
+                split_name = remaining_names
+            (base_param_name,) = split_name
+            self.temp_module_param_dicts[module][base_param_name] = param
+
+        self.orig_model_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts)
+
+    def __enter__(self):
+        for module, temp_param_dict in self.temp_module_param_dicts.items():
+            object.__setattr__(module, "_parameters", temp_param_dict)
+
+    def __exit__(self, type, value, traceback):
+        for module, orig_param_dict in self.orig_model_param_dicts.items():
+            object.__setattr__(module, "_parameters", orig_param_dict)
+
+
 class Kernel(Module):
     r"""
     Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor`
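A hypothetical illustration (not in the diff) of how the helper above is meant to be driven; my_kernel, x1, and x2 are stand-ins:

    # Gather external references to the kernel's parameters...
    param_names, params = zip(*my_kernel.named_parameters())

    # ...and temporarily re-point the kernel's parameter dicts at them, so that
    # calling forward does not close over the originally registered Parameters.
    with _autograd_kernel_hack(my_kernel, params, param_names):
        covar = my_kernel.forward(x1, x2)  # gradients flow to `params`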
@@ -206,6 +246,37 @@ def __init__(
         # TODO: Remove this on next official PyTorch release.
         self.__pdist_supports_batch = True
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        r"""
+        Determines whether or not the kernel is lazily evaluated.
+
+        If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated
+        over x1 and x2.
+
+        If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function.
+        The kernel function will only be evaluated when either
+        - a mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or
+        - an indexing operation is performed on the kernel matrix to select specific covariance entries.
+
+        In general, _lazily_evaluate should return True (this option is more efficient), unless lazy evaluation
+        offers no gains and there is specific structure that will be lost with lazy evaluation
+        (e.g. low-rank/Nystrom approximations).
+        """
+        return True
+
+    def _kernel_linear_operator_covar_func(
+        self, x1: Tensor, x2: Tensor, *params: torch.nn.Parameter, param_names: Dict[str] = {}, **kwargs: Any
+    ) -> Union[Tensor, LinearOperator]:
+        # This is the `covar_func` that is passed into KernelLinearOperator.
+        # This function calls self.forward, but does so in a way that no parameters are closed over
+        # (by using the _autograd_kernel_hack context manager).
+        if any(param.requires_grad for param in params):
+            with _autograd_kernel_hack(self, params, param_names):
+                return self.forward(x1, x2, **kwargs)
+        else:
+            return self.forward(x1, x2, **kwargs)
+
     def _lengthscale_param(self, m: Kernel) -> Tensor:
         # Used by the lengthscale_prior
         return m.lengthscale
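A minimal usage sketch (not part of the diff) of what this flag controls, assuming a plain RBFKernel whose _lazily_evaluate defaults to True:

    import torch
    import gpytorch

    kernel = gpytorch.kernels.RBFKernel()
    x = torch.randn(100, 3)

    covar = kernel(x, x)      # lazy: evaluation deferred until a solve, logdet, or indexing
    dense = covar.to_dense()  # forces evaluation of the full 100 x 100 matrix

    with gpytorch.settings.lazily_evaluate_kernels(False):
        eager = kernel(x, x)  # eager: the covariance is computed up front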
@@ -451,7 +522,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
             yield kernel
 
     def __call__(
-        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **kwargs
     ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
         r"""
         Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -508,7 +579,7 @@ def __call__(
         )
 
         if diag:
-            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
+            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **kwargs)
             # Did this Kernel eat the diag option?
             # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
             if not isinstance(res, LazyEvaluatedKernelTensor):
@@ -517,11 +588,42 @@ def __call__(
                 return res
 
         else:
-            if settings.lazily_evaluate_kernels.on():
-                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
+            if (settings.lazily_evaluate_kernels.on() and self._lazily_evaluate) or last_dim_is_batch:
+                num_outputs_per_input = self.num_outputs_per_input(x1_, x2_)
+                named_parameters = tuple(self.named_parameters())
+
+                if last_dim_is_batch:
+                    x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
+                    x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
+
+                if len(named_parameters):
+                    param_names, params = zip(*named_parameters)
+                    param_batch_shapes = [self.batch_shape] * len(params)
+                    if last_dim_is_batch:
+                        params = [
+                            param.unsqueeze(len(param_batch_shape)).transpose(-1, len(param_batch_shape))
+                            for param, param_batch_shape in zip(params, param_batch_shapes)
+                        ]
+                        param_batch_shapes = [
+                            torch.Size([*param_batch_shape, x1_.size(-3)]) for param_batch_shape in param_batch_shapes
+                        ]
+                    res = KernelLinearOperator(
+                        x1_,
+                        x2_,
+                        *params,
+                        covar_func=self._kernel_linear_operator_covar_func,
+                        num_outputs_per_input=num_outputs_per_input,
+                        param_batch_shapes=param_batch_shapes,
+                        param_names=param_names,
+                        **kwargs,
+                    )
+                else:
+                    res = KernelLinearOperator(
+                        x1_, x2_, covar_func=self.forward, num_outputs_per_input=num_outputs_per_input, **kwargs
+                    )
             else:
                 res = to_linear_operator(
-                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
+                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **kwargs)
                 )
             return res
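A rough sketch (not in the diff) of the user-visible effect of this branch, assuming lazy evaluation is left on and the kernel's _lazily_evaluate is True:

    import torch
    import gpytorch
    from linear_operator.operators import KernelLinearOperator

    kernel = gpytorch.kernels.MaternKernel(nu=1.5)
    x = torch.randn(50, 2)

    covar = kernel(x, x)
    # Before this commit the lazy path returned a LazyEvaluatedKernelTensor; with this
    # change it is expected to be a KernelLinearOperator built from kernel.forward and
    # the kernel's named parameters.
    print(isinstance(covar, KernelLinearOperator))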

@@ -593,13 +695,17 @@ class AdditiveKernel(Kernel):
     :param kernels: Kernels to add together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(AdditiveKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(AdditiveKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return all(k._lazily_evaluate for k in self.kernels)
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         res = ZeroLinearOperator() if not diag else 0
@@ -635,13 +741,17 @@ class ProductKernel(Kernel):
     :param kernels: Kernels to multiply together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(ProductKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(ProductKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return False
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         x1_eq_x2 = torch.equal(x1, x2)
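A hypothetical check (not in the diff) of how the flag composes for kernels built with + and *, which construct AdditiveKernel and ProductKernel respectively:

    import gpytorch

    k_sum = gpytorch.kernels.RBFKernel() + gpytorch.kernels.MaternKernel(nu=2.5)
    k_prod = gpytorch.kernels.RBFKernel() * gpytorch.kernels.PeriodicKernel()

    print(k_sum._lazily_evaluate)   # True: every summand lazily evaluates
    print(k_prod._lazily_evaluate)  # False: ProductKernel always evaluates eagerly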
