Skip to content

Commit 19e67f3

Browse files
authored
Merge pull request #1786 from wjmaddox/kp_logdet
Add _logdet and _solve to KroneckerProductLazyTensor
2 parents 452443a + 8896ad7 commit 19e67f3

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

gpytorch/lazy/kronecker_product_lazy_tensor.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,15 @@ def inverse(self):
150150
inverses = [lt.inverse() for lt in self.lazy_tensors]
151151
return self.__class__(*inverses)
152152

153-
def inv_matmul(self, right_tensor, left_tensor=None):
154-
# TODO: Investigate under what conditions computing individual inverses makes sense
155-
# For now, retain existing behavior
156-
return super().inv_matmul(right_tensor=right_tensor, left_tensor=left_tensor)
153+
def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
154+
if inv_quad_rhs is not None:
155+
inv_quad_term, _ = super().inv_quad_logdet(
156+
inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
157+
)
158+
else:
159+
inv_quad_term = None
160+
logdet_term = self._logdet() if logdet else None
161+
return inv_quad_term, logdet_term
157162

158163
@cached(name="cholesky")
159164
def _cholesky(self, upper=False):
@@ -183,22 +188,47 @@ def _get_indices(self, row_index, col_index, *batch_indices):
183188

184189
return res
185190

186-
def _inv_matmul(self, right_tensor, left_tensor=None):
191+
def _solve(self, rhs, preconditioner=None, num_tridiag=0):
187192
# Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
193+
# we perform the solve first before worrying about any tridiagonal matrices
194+
188195
tsr_shapes = [q.size(-1) for q in self.lazy_tensors]
189-
n_rows = right_tensor.size(-2)
190-
batch_shape = _mul_broadcast_shape(self.shape[:-2], right_tensor.shape[:-2])
196+
n_rows = rhs.size(-2)
197+
batch_shape = _mul_broadcast_shape(self.shape[:-2], rhs.shape[:-2])
191198
perm_batch = tuple(range(len(batch_shape)))
192-
y = right_tensor.clone().expand(*batch_shape, *right_tensor.shape[-2:])
199+
y = rhs.clone().expand(*batch_shape, *rhs.shape[-2:])
193200
for n, q in zip(tsr_shapes, self.lazy_tensors):
194201
# for KroneckerProductTriangularLazyTensor this inv_matmul is very cheap
195202
y = q.inv_matmul(y.reshape(*batch_shape, n, -1))
196203
y = y.reshape(*batch_shape, n, n_rows // n, -1).permute(*perm_batch, -2, -3, -1)
197204
res = y.reshape(*batch_shape, n_rows, -1)
205+
206+
if num_tridiag == 0:
207+
return res
208+
else:
209+
# we need to return the tridiagonal matrix, so we return the eigenvalues
210+
# in general, this should not be called because log determinant estimation
211+
# has a closed form and is implemented in _logdet
212+
# TODO: make this more efficient
213+
evals, _ = self.diagonalization()
214+
evals_repeated = evals.unsqueeze(0).repeat(num_tridiag, *[1] * evals.ndim)
215+
lazy_evals = DiagLazyTensor(evals_repeated)
216+
batch_repeated_evals = lazy_evals.evaluate()
217+
return res, batch_repeated_evals
218+
219+
def _inv_matmul(self, right_tensor, left_tensor=None):
220+
# if _inv_matmul is called, we ignore the eigenvalue handling
221+
# this is efficient because of the structure of the lazy tensor
222+
res = self._solve(rhs=right_tensor)
198223
if left_tensor is not None:
199224
res = left_tensor @ res
200225
return res
201226

227+
def _logdet(self):
228+
evals, _ = self.diagonalization()
229+
logdet = evals.clamp(min=1e-7).log().sum(-1)
230+
return logdet
231+
202232
def _matmul(self, rhs):
203233
is_vec = rhs.ndimension() == 1
204234
if is_vec:

test/lazy/test_lazy_evaluated_kernel_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def test_half(self):
137137

138138
class TestLazyEvaluatedKernelTensorMultitaskBatch(TestLazyEvaluatedKernelTensorBatch):
139139
seed = 0
140+
skip_slq_tests = True # we skip these because of the Kronecker structure
140141

141142
def create_lazy_tensor(self):
142143
kern = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=3, rank=2)

test/lazy/test_sum_kronecker_lazy_tensor.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,3 @@ def evaluate_lazy_tensor(self, lazy_tensor):
4747
lazy_tensor.lazy_tensors[1].lazy_tensors[0].tensor, lazy_tensor.lazy_tensors[1].lazy_tensors[1].tensor
4848
)
4949
return res1 + res2
50-
51-
def test_inv_quad_logdet(self):
52-
# mock call cg here
53-
self.__class__.should_call_cg = True
54-
super().test_inv_quad_logdet()
55-
self.__class__.should_call_cg = False
56-
57-
def test_inv_quad_logdet_no_reduce(self):
58-
self.__class__.should_call_cg = True
59-
super().test_inv_quad_logdet_no_reduce()
60-
self.__class__.should_call_cg = False
61-
62-
def test_root_decomposition_cholesky(self):
63-
self.__class__.should_call_cg = True
64-
super().test_root_decomposition_cholesky()
65-
self.__class__.should_call_cg = False

0 commit comments

Comments
 (0)