Skip to content

Commit 98dd616

Browse files
Merge pull request #1717 from AustinT/sgpr-diag-correction
Add setting to turn off diagonal correction for InducingPointKernel.
2 parents 3e9bdb2 + ba238fb commit 98dd616

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

gpytorch/kernels/inducing_point_kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77

8+
from .. import settings
89
from ..distributions import MultivariateNormal
910
from ..lazy import DiagLazyTensor, LowRankRootAddedDiagLazyTensor, LowRankRootLazyTensor, MatmulLazyTensor, delazify
1011
from ..mlls import InducingPointKernelAddedLossTerm
@@ -61,7 +62,7 @@ def _get_covariance(self, x1, x2):
6162
covar = LowRankRootLazyTensor(k_ux1.matmul(self._inducing_inv_root))
6263

6364
# Diagonal correction for predictive posterior
64-
if not self.training:
65+
if not self.training and settings.sgpr_diagonal_correction.on():
6566
correction = (self.base_kernel(x1, x2, diag=True) - covar.diag()).clamp(0, math.inf)
6667
covar = LowRankRootAddedDiagLazyTensor(covar, DiagLazyTensor(correction))
6768
else:

gpytorch/settings.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,19 @@ class prior_mode(_feature_flag):
640640
_default = False
641641

642642

643+
class sgpr_diagonal_correction(_feature_flag):
644+
"""
645+
If set to true, during posterior prediction the variances of the InducingPointKernel
646+
will be corrected to match the variances of the exact kernel.
647+
648+
If false then no such correction will be performed (this is the default in other libraries).
649+
650+
(Default: True)
651+
"""
652+
653+
_default = True
654+
655+
643656
class skip_logdet_forward(_feature_flag):
644657
"""
645658
.. warning:

test/kernels/test_inducing_point_kernel.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,14 @@ def test_kernel_output(self):
4848
output = model.likelihood(model(test_x))
4949
_ = output.mean + output.variance # Compute something to break through any lazy evaluations
5050
self.assertTrue(ps_mock.called)
51+
52+
# Check whether changing diagonal correction makes a difference (ensuring that cache is cleared)
53+
model.train()
54+
model.eval()
55+
with gpytorch.settings.sgpr_diagonal_correction(True), torch.no_grad():
56+
output_mean_correct = model(test_x).mean
57+
model.train()
58+
model.eval()
59+
with gpytorch.settings.sgpr_diagonal_correction(False), torch.no_grad():
60+
output_mean_no_correct = model(test_x).mean
61+
self.assertNotAlmostEqual(output_mean_correct.sum().item(), output_mean_no_correct.sum().item())

0 commit comments

Comments
 (0)