Skip to content

Commit 8896ad7

Browse files
committed
Merge branch 'gpytorch_master' into kp_logdet
1 parent d42ce99 commit 8896ad7

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

gpytorch/lazy/kronecker_product_lazy_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def _get_indices(self, row_index, col_index, *batch_indices):
190190

191191
def _solve(self, rhs, preconditioner=None, num_tridiag=0):
192192
# Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
193+
# we perform the solve first before worrying about any tridiagonal matrices
194+
193195
tsr_shapes = [q.size(-1) for q in self.lazy_tensors]
194196
n_rows = rhs.size(-2)
195197
batch_shape = _mul_broadcast_shape(self.shape[:-2], rhs.shape[:-2])
@@ -200,10 +202,13 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
200202
y = q.inv_matmul(y.reshape(*batch_shape, n, -1))
201203
y = y.reshape(*batch_shape, n, n_rows // n, -1).permute(*perm_batch, -2, -3, -1)
202204
res = y.reshape(*batch_shape, n_rows, -1)
205+
203206
if num_tridiag == 0:
204207
return res
205208
else:
206209
# we need to return the t mat, so we return the eigenvalues
210+
# in general, this should not be called because log determinant estimation
211+
# is closed form and is implemented in _logdet
207212
# TODO: make this more efficient
208213
evals, _ = self.diagonalization()
209214
evals_repeated = evals.unsqueeze(0).repeat(num_tridiag, *[1] * evals.ndim)
@@ -213,6 +218,7 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
213218

214219
def _inv_matmul(self, right_tensor, left_tensor=None):
215220
# if _inv_matmul is called, we ignore the eigenvalue handling
221+
# this is efficient because of the structure of the lazy tensor
216222
res = self._solve(rhs=right_tensor)
217223
if left_tensor is not None:
218224
res = left_tensor @ res

0 commit comments

Comments
 (0)