Skip to content

Commit f5d1503

Browse files
committed
inv_quad_logdet method for sumKronecker
1 parent b237019 commit f5d1503

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

gpytorch/lazy/sum_kronecker_lazy_tensor.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22

3-
from .. import settings
43
from .kronecker_product_lazy_tensor import KroneckerProductLazyTensor
54
from .sum_lazy_tensor import SumLazyTensor
65

@@ -36,9 +35,6 @@ def _sum_formulation(self):
3635
return inv_root_times_lt1
3736

3837
def _solve(self, rhs, preconditioner=None, num_tridiag=0):
39-
if self.shape[-1] <= settings.max_cholesky_size.value():
40-
return super()._solve(rhs=rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)
41-
4238
inner_mat = self._sum_formulation
4339
# root decomposition may not be trustworthy if it uses a different method than
4440
# root_inv_decomposition. so ensure that we call this locally
@@ -73,3 +69,19 @@ def _root_inv_decomposition(self, initial_vectors=None):
7369
inner_mat_root_inv = inner_mat.root_inv_decomposition().root
7470
inv_root = lt2_root_inv.matmul(inner_mat_root_inv)
7571
return inv_root
72+
73+
def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
74+
inv_quad_term = None
75+
logdet_term = None
76+
77+
if inv_quad_rhs is not None:
78+
solve = self._solve(inv_quad_rhs)
79+
inv_quad_term = (inv_quad_rhs * solve).sum(-2)
80+
81+
if inv_quad_term.numel() and reduce_inv_quad:
82+
inv_quad_term = inv_quad_term.sum(-1)
83+
84+
if logdet:
85+
logdet_term = self._logdet()
86+
87+
return inv_quad_term, logdet_term

test/lazy/test_sum_kronecker_lazy_tensor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TestSumKroneckerLazyTensor(LazyTensorTestCase, unittest.TestCase):
2222
seed = 0
2323
should_call_lanczos = True
2424
should_call_cg = False
25-
skip_slq_tests = True
25+
skip_slq_tests = False
2626

2727
def create_lazy_tensor(self):
2828
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
@@ -47,3 +47,19 @@ def evaluate_lazy_tensor(self, lazy_tensor):
4747
lazy_tensor.lazy_tensors[1].lazy_tensors[0].tensor, lazy_tensor.lazy_tensors[1].lazy_tensors[1].tensor
4848
)
4949
return res1 + res2
50+
51+
def test_inv_quad_logdet(self):
52+
# mock call cg here
53+
self.__class__.should_call_cg = True
54+
super().test_inv_quad_logdet()
55+
self.__class__.should_call_cg = False
56+
57+
def test_inv_quad_logdet_no_reduce(self):
58+
self.__class__.should_call_cg = True
59+
super().test_inv_quad_logdet_no_reduce()
60+
self.__class__.should_call_cg = False
61+
62+
def test_root_decomposition_cholesky(self):
63+
self.__class__.should_call_cg = True
64+
super().test_root_decomposition_cholesky()
65+
self.__class__.should_call_cg = False

0 commit comments

Comments
 (0)