Skip to content

Commit 3b4c74d

Browse files
authored
Merge pull request #1674 from wjmaddox/sum_kronecker_iqldt
inv_quad_logdet method for sumKronecker
2 parents b237019 + 381ff30 commit 3b4c74d

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

gpytorch/lazy/sum_kronecker_lazy_tensor.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22

3-
from .. import settings
43
from .kronecker_product_lazy_tensor import KroneckerProductLazyTensor
54
from .sum_lazy_tensor import SumLazyTensor
65

@@ -36,9 +35,6 @@ def _sum_formulation(self):
3635
return inv_root_times_lt1
3736

3837
def _solve(self, rhs, preconditioner=None, num_tridiag=0):
39-
if self.shape[-1] <= settings.max_cholesky_size.value():
40-
return super()._solve(rhs=rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)
41-
4238
inner_mat = self._sum_formulation
4339
# root decomposition may not be trustworthy if it uses a different method than
4440
# root_inv_decomposition. so ensure that we call this locally
@@ -73,3 +69,19 @@ def _root_inv_decomposition(self, initial_vectors=None):
7369
inner_mat_root_inv = inner_mat.root_inv_decomposition().root
7470
inv_root = lt2_root_inv.matmul(inner_mat_root_inv)
7571
return inv_root
72+
73+
def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
74+
inv_quad_term = None
75+
logdet_term = None
76+
77+
if inv_quad_rhs is not None:
78+
solve = self.inv_matmul(inv_quad_rhs)
79+
inv_quad_term = (inv_quad_rhs * solve).sum(-2)
80+
81+
if inv_quad_term.numel() and reduce_inv_quad:
82+
inv_quad_term = inv_quad_term.sum(-1)
83+
84+
if logdet:
85+
logdet_term = self._logdet()
86+
87+
return inv_quad_term, logdet_term

gpytorch/likelihoods/multitask_gaussian_likelihood.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DiagLazyTensor,
1313
KroneckerProductDiagLazyTensor,
1414
KroneckerProductLazyTensor,
15+
LazyEvaluatedKernelTensor,
1516
RootLazyTensor,
1617
)
1718
from ..likelihoods import Likelihood, _GaussianLikelihoodBase
@@ -88,6 +89,10 @@ def marginal(self, function_dist, *params, **kwargs):
8889
"""
8990
mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
9091

92+
# ensure that sumKroneckerLT is actually called
93+
if isinstance(covar, LazyEvaluatedKernelTensor):
94+
covar = covar.evaluate_kernel()
95+
9196
covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
9297
covar = covar + covar_kron_lt
9398

@@ -126,7 +131,7 @@ def _shaped_noise_covar(self, shape, add_noise=True, *params, **kwargs):
126131

127132
def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
128133
noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
129-
noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
134+
noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:])
130135
return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)
131136

132137

test/lazy/test_sum_kronecker_lazy_tensor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TestSumKroneckerLazyTensor(LazyTensorTestCase, unittest.TestCase):
2222
seed = 0
2323
should_call_lanczos = True
2424
should_call_cg = False
25-
skip_slq_tests = True
25+
skip_slq_tests = False
2626

2727
def create_lazy_tensor(self):
2828
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
@@ -47,3 +47,19 @@ def evaluate_lazy_tensor(self, lazy_tensor):
4747
lazy_tensor.lazy_tensors[1].lazy_tensors[0].tensor, lazy_tensor.lazy_tensors[1].lazy_tensors[1].tensor
4848
)
4949
return res1 + res2
50+
51+
def test_inv_quad_logdet(self):
52+
# mock call cg here
53+
self.__class__.should_call_cg = True
54+
super().test_inv_quad_logdet()
55+
self.__class__.should_call_cg = False
56+
57+
def test_inv_quad_logdet_no_reduce(self):
58+
self.__class__.should_call_cg = True
59+
super().test_inv_quad_logdet_no_reduce()
60+
self.__class__.should_call_cg = False
61+
62+
def test_root_decomposition_cholesky(self):
63+
self.__class__.should_call_cg = True
64+
super().test_root_decomposition_cholesky()
65+
self.__class__.should_call_cg = False

0 commit comments

Comments
 (0)