
Commit 6d58b66

Commit message: WIP
1 parent: 7965e3a

20 files changed, +261 -53 lines

gpytorch/kernels/additive_structure_kernel.py

Lines changed: 11 additions & 7 deletions

@@ -34,13 +34,6 @@ class AdditiveStructureKernel(Kernel):
         Passed down to the `base_kernel`.
     """
 
-    @property
-    def is_stationary(self) -> bool:
-        """
-        Kernel is stationary if the base kernel is stationary.
-        """
-        return self.base_kernel.is_stationary
-
     def __init__(
         self,
         base_kernel: Kernel,
@@ -51,6 +44,17 @@ def __init__(
         self.base_kernel = base_kernel
         self.num_dims = num_dims
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return self.base_kernel._lazily_evaluate
+
+    @property
+    def is_stationary(self) -> bool:
+        """
+        Kernel is stationary if the base kernel is stationary.
+        """
+        return self.base_kernel.is_stationary
+
     def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
         if last_dim_is_batch:
             raise RuntimeError("AdditiveStructureKernel does not accept the last_dim_is_batch argument.")
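
Note: the wrapper now delegates both the new private `_lazily_evaluate` hook and `is_stationary` to the wrapped kernel. A minimal sketch of that delegation from the caller's side (the `RBFKernel` base and `num_dims` value here are illustrative, not from the commit):

    import gpytorch

    base = gpytorch.kernels.RBFKernel()
    additive = gpytorch.kernels.AdditiveStructureKernel(base, num_dims=3)

    # Both flags are inherited from the wrapped kernel
    assert additive.is_stationary == base.is_stationary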

gpytorch/kernels/cosine_kernel.py

Lines changed: 4 additions & 2 deletions

@@ -56,8 +56,6 @@ class CosineKernel(Kernel):
         >>> covar = covar_module(x)  # Output: LazyVariable of size (2 x 10 x 10)
     """
 
-    is_stationary = True
-
     def __init__(
         self,
         period_length_prior: Optional[Prior] = None,
@@ -85,6 +83,10 @@ def __init__(
 
         self.register_constraint("raw_period_length", period_length_constraint)
 
+    @property
+    def is_stationary(self):
+        return True
+
     @property
     def period_length(self):
         return self.raw_period_length_constraint.transform(self.raw_period_length)
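
Note: `is_stationary` moves from a class attribute to a property, matching how the other kernels in this commit expose it. Stationarity means the covariance depends only on the difference between inputs, so it is invariant under translation. A quick sketch of that invariance (an illustrative cosine covariance, not necessarily the library's exact formula):

    import math
    import torch

    def cosine_covariance(x1, x2, period_length=1.0):
        # depends only on pairwise distances, hence stationary
        return torch.cos(2 * math.pi * torch.cdist(x1, x2) / period_length)

    x = torch.randn(5, 1)
    K1 = cosine_covariance(x, x)
    K2 = cosine_covariance(x + 10.0, x + 10.0)  # translated inputs
    assert torch.allclose(K1, K2, atol=1e-5)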

gpytorch/kernels/cylindrical_kernel.py

Lines changed: 1 addition & 3 deletions

@@ -4,7 +4,6 @@
 
 import torch
 
-from .. import settings
 from ..constraints import Interval, Positive
 from ..priors import Prior
 from .kernel import Kernel
@@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
             else:
                 angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))
 
-        with settings.lazily_evaluate_kernels(False):
-            radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
+        radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
         return radial_kernel.mul(angular_kernel)
 
     def kuma(self, x: torch.Tensor) -> torch.Tensor:
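
Note: the `settings.lazily_evaluate_kernels(False)` context is replaced by calling `forward` directly, which computes the radial covariance eagerly instead of toggling a global setting. In gpytorch, calling a kernel normally returns a lazily evaluated object, while `forward` does the actual computation. A sketch of the distinction, assuming a recent gpytorch where the lazy tensor exposes `to_dense()` (using an `RBFKernel` for illustration):

    import torch
    import gpytorch

    k = gpytorch.kernels.RBFKernel()
    x1, x2 = torch.randn(10, 2), torch.randn(8, 2)

    lazy = k(x1, x2)           # lazily evaluated kernel tensor
    eager = k.forward(x1, x2)  # concrete computation, no lazy wrapper

    assert torch.allclose(lazy.to_dense(), eager)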

gpytorch/kernels/grid_interpolation_kernel.py

Lines changed: 7 additions & 0 deletions

@@ -121,6 +121,13 @@ def __init__(
         )
         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point
+        # kernel matrix always needs to be evaluated, regardless of the size of x1 and x2), and the
+        # InterpolatedLinearOperator structure is needed for fast predictions.
+        return False
+
     @property
     def _tight_grid_bounds(self):
         grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
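
Note: this is the pattern the commit uses to opt individual kernels out of lazy evaluation. For a custom kernel with its own structured operator, the same override would look like this (a sketch only; `_lazily_evaluate` is a private hook introduced by this commit, and its name or semantics may change):

    import gpytorch

    class MyStructuredKernel(gpytorch.kernels.Kernel):
        @property
        def _lazily_evaluate(self) -> bool:
            # build the structured operator eagerly, as SKI does above
            return False

        def forward(self, x1, x2, diag=False, **params):
            ...  # return a structured LinearOperator here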

gpytorch/kernels/grid_kernel.py

Lines changed: 9 additions & 2 deletions

@@ -44,8 +44,6 @@ class GridKernel(Kernel):
     http://www.cs.cmu.edu/~andrewgw/manet.pdf
     """
 
-    is_stationary = True
-
     def __init__(
         self,
         base_kernel: Kernel,
@@ -66,6 +64,15 @@ def __init__(
         if not self.interpolation_mode:
             self.register_buffer("full_grid", create_data_from_grid(grid))
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # Toeplitz structure is very efficient; no need to lazily evaluate
+        return False
+
+    @property
+    def is_stationary(self) -> bool:
+        return True
+
     def _clear_cache(self):
         if hasattr(self, "_cached_kernel_mat"):
             del self._cached_kernel_mat
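
Note on the "Toeplitz structure" rationale: a stationary kernel evaluated on a regularly spaced grid produces a Toeplitz matrix (constant along each diagonal), which is fully determined by its first row and column, so there is little to gain from deferring evaluation. A sketch of the property (illustrative, not the commit's code):

    import torch
    import gpytorch

    grid = torch.linspace(0, 1, 6).unsqueeze(-1)          # regular 1-D grid
    K = gpytorch.kernels.RBFKernel().forward(grid, grid)

    # constant along each diagonal: K[i + 1, j + 1] == K[i, j]
    assert torch.allclose(K[1:, 1:], K[:-1, :-1], atol=1e-6)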

gpytorch/kernels/index_kernel.py

Lines changed: 6 additions & 0 deletions

@@ -76,6 +76,12 @@ def __init__(
 
         self.register_constraint("raw_var", var_constraint)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # IndexKernel does not need lazy evaluation, since the complete B B^T + D_v matrix
+        # is always computed regardless of x1 and x2.
+        return False
+
     @property
     def var(self):
         return self.raw_var_constraint.transform(self.raw_var)
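
Note: the comment's B B^T + D_v is the full task covariance built from a low-rank factor and a diagonal of per-task variances; because it is formed in full no matter which indices are requested, lazy evaluation buys nothing. A sketch with made-up sizes (4 tasks, rank 2):

    import torch

    B = torch.randn(4, 2)        # low-rank factor (covar_factor)
    v = torch.rand(4)            # per-task variances D_v
    K_tasks = B @ B.T + torch.diag(v)

    idx = torch.tensor([0, 2, 3])
    K_sub = K_tasks[idx][:, idx]  # evaluation just selects rows/columns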

gpytorch/kernels/inducing_point_kernel.py

Lines changed: 6 additions & 0 deletions

@@ -47,6 +47,12 @@ def _clear_cache(self):
         if hasattr(self, "_cached_kernel_inv_root"):
             del self._cached_kernel_inv_root
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # InducingPointKernels should not lazily evaluate; to use the Woodbury formula,
+        # we want the Kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
+        return False
+
     @property
     def _inducing_mat(self):
         if not self.training and hasattr(self, "_cached_kernel_mat"):
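
Note: the Woodbury identity is why the low-rank operator matters here: with a rank-k factor U and noise variance s, solves against U U^T + s I reduce to a k x k system instead of an n x n one. A numerical sketch of the identity (double precision for a tight comparison; not the library's implementation):

    import torch

    n, k, s = 100, 5, 0.1
    U = torch.randn(n, k, dtype=torch.float64)
    b = torch.randn(n, dtype=torch.float64)

    # Woodbury: (U U^T + s I)^{-1} b = b/s - U (I + U^T U / s)^{-1} U^T b / s^2
    inner = torch.eye(k, dtype=torch.float64) + (U.T @ U) / s
    woodbury = b / s - U @ torch.linalg.solve(inner, U.T @ b) / s**2

    direct = torch.linalg.solve(U @ U.T + s * torch.eye(n, dtype=torch.float64), b)
    assert torch.allclose(woodbury, direct)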

gpytorch/kernels/keops/matern_kernel.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 import math
 
 import torch
-from linear_operator.operators import KeOpsLinearOperator
+from linear_operator.operators import KernelLinearOperator
 
 from ... import settings
 from .keops_kernel import KeOpsKernel
@@ -92,7 +92,7 @@ def forward(self, x1, x2, diag=False, **params):
                 return self.covar_func(x1_, x2_, diag=True)
 
             covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
-            return KeOpsLinearOperator(x1_, x2_, covar_func)
+            return KernelLinearOperator(x1_, x2_, covar_func=covar_func)
 
 except ImportError:
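
Note: both KeOps kernels (this file and keops/rbf_kernel.py below) make the same one-line swap from `KeOpsLinearOperator` to `KernelLinearOperator`, passing the covariance callback as the `covar_func` keyword. A sketch of the call pattern, with a plain torch callback standing in for the KeOps formula (illustrative only):

    import torch
    from linear_operator.operators import KernelLinearOperator

    def covar_func(x1, x2, **params):
        # stand-in covariance; the real kernels build a KeOps lazy tensor
        return torch.exp(-torch.cdist(x1, x2) ** 2 / 2)

    x1, x2 = torch.randn(10, 3), torch.randn(8, 3)
    lazy_K = KernelLinearOperator(x1, x2, covar_func=covar_func)
    K = lazy_K.to_dense()  # materializes the full 10 x 8 matrix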

gpytorch/kernels/keops/rbf_kernel.py

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 import torch
-from linear_operator.operators import KeOpsLinearOperator
+from linear_operator.operators import KernelLinearOperator
 
 from ... import settings
 from ..rbf_kernel import postprocess_rbf
@@ -54,7 +54,7 @@ def forward(self, x1, x2, diag=False, **params):
             if diag:
                 return covar_func(x1_, x2_, diag=True)
 
-            return KeOpsLinearOperator(x1_, x2_, covar_func)
+            return KernelLinearOperator(x1_, x2_, covar_func=covar_func)
 
 except ImportError:
