@@ -190,6 +190,8 @@ def _get_indices(self, row_index, col_index, *batch_indices):
190190
191191 def _solve (self , rhs , preconditioner = None , num_tridiag = 0 ):
192192 # Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
193+ # we perform the solve first before worrying about any tridiagonal matrices
194+
193195 tsr_shapes = [q .size (- 1 ) for q in self .lazy_tensors ]
194196 n_rows = rhs .size (- 2 )
195197 batch_shape = _mul_broadcast_shape (self .shape [:- 2 ], rhs .shape [:- 2 ])
@@ -200,10 +202,13 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
200202 y = q .inv_matmul (y .reshape (* batch_shape , n , - 1 ))
201203 y = y .reshape (* batch_shape , n , n_rows // n , - 1 ).permute (* perm_batch , - 2 , - 3 , - 1 )
202204 res = y .reshape (* batch_shape , n_rows , - 1 )
205+
203206 if num_tridiag == 0 :
204207 return res
205208 else :
206209 # we need to return the t mat, so we return the eigenvalues
210+ # in general, this should not be called because log determinant estimation
211+ # is closed form and is implemented in _logdet
207212 # TODO: make this more efficient
208213 evals , _ = self .diagonalization ()
209214 evals_repeated = evals .unsqueeze (0 ).repeat (num_tridiag , * [1 ] * evals .ndim )
@@ -213,6 +218,7 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
213218
214219 def _inv_matmul (self , right_tensor , left_tensor = None ):
215220 # if _inv_matmul is called, we ignore the eigenvalue handling
221+ # this is efficient because of the structure of the lazy tensor
216222 res = self ._solve (rhs = right_tensor )
217223 if left_tensor is not None :
218224 res = left_tensor @ res
0 commit comments