@@ -150,10 +150,15 @@ def inverse(self):
         inverses = [lt.inverse() for lt in self.lazy_tensors]
         return self.__class__(*inverses)
 
-    def inv_matmul(self, right_tensor, left_tensor=None):
-        # TODO: Investigate under what conditions computing individual inverses makes sense
-        # For now, retain existing behavior
-        return super().inv_matmul(right_tensor=right_tensor, left_tensor=left_tensor)
+    def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
+        if inv_quad_rhs is not None:
+            inv_quad_term, _ = super().inv_quad_logdet(
+                inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
+            )
+        else:
+            inv_quad_term = None
+        logdet_term = self._logdet() if logdet else None
+        return inv_quad_term, logdet_term
 
     @cached(name="cholesky")
     def _cholesky(self, upper=False):
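As a sanity check on what the new `inv_quad_logdet` returns, both terms can be computed densely with plain torch. This is a minimal sketch with made-up SPD factors `A` and `B` (not part of this commit): for K = A kron B and right-hand sides y, the method yields sum_i y_i^T K^-1 y_i (with reduce_inv_quad=True) and log|K|, and the Kronecker structure makes the log determinant separable.

import torch

# Hypothetical dense check: A, B are small SPD factors, K their Kronecker product.
A = torch.randn(3, 3); A = A @ A.T + 3 * torch.eye(3)
B = torch.randn(2, 2); B = B @ B.T + 2 * torch.eye(2)
K = torch.kron(A, B)
y = torch.randn(6, 4)

# inv_quad term with reduce_inv_quad=True: sum_i y_i^T K^-1 y_i
inv_quad = (y * torch.linalg.solve(K, y)).sum()
# logdet term; for A (n x n) and B (m x m): log|A kron B| = m log|A| + n log|B|
logdet = torch.logdet(K)
assert torch.allclose(logdet, 2 * torch.logdet(A) + 3 * torch.logdet(B), atol=1e-4)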
@@ -183,22 +188,47 @@ def _get_indices(self, row_index, col_index, *batch_indices):
 
         return res
 
-    def _inv_matmul(self, right_tensor, left_tensor=None):
+    def _solve(self, rhs, preconditioner=None, num_tridiag=0):
         # Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
+        # We perform the solve first, before handling any tridiagonal matrices.
+
         tsr_shapes = [q.size(-1) for q in self.lazy_tensors]
-        n_rows = right_tensor.size(-2)
-        batch_shape = _mul_broadcast_shape(self.shape[:-2], right_tensor.shape[:-2])
+        n_rows = rhs.size(-2)
+        batch_shape = _mul_broadcast_shape(self.shape[:-2], rhs.shape[:-2])
         perm_batch = tuple(range(len(batch_shape)))
-        y = right_tensor.clone().expand(*batch_shape, *right_tensor.shape[-2:])
+        y = rhs.clone().expand(*batch_shape, *rhs.shape[-2:])
         for n, q in zip(tsr_shapes, self.lazy_tensors):
             # for KroneckerProductTriangularLazyTensor this inv_matmul is very cheap
             y = q.inv_matmul(y.reshape(*batch_shape, n, -1))
             y = y.reshape(*batch_shape, n, n_rows // n, -1).permute(*perm_batch, -2, -3, -1)
         res = y.reshape(*batch_shape, n_rows, -1)
+
+        if num_tridiag == 0:
+            return res
+        else:
+            # We need to return the tridiagonal matrices, so we return the eigenvalues.
+            # In general this path should not be taken, because log determinant
+            # estimation has a closed form and is implemented in _logdet.
+            # TODO: make this more efficient
+            evals, _ = self.diagonalization()
+            evals_repeated = evals.unsqueeze(0).repeat(num_tridiag, *[1] * evals.ndim)
+            lazy_evals = DiagLazyTensor(evals_repeated)
+            batch_repeated_evals = lazy_evals.evaluate()
+            return res, batch_repeated_evals
+
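The reshape/permute loop above is the standard trick for applying (A kron B)^-1 one factor at a time, without ever materializing the full product. A minimal dense sketch mirroring the loop body, with hypothetical factors `A` and `B` (illustrative only, not from this codebase):

import torch

A = torch.randn(3, 3); A = A @ A.T + 3 * torch.eye(3)
B = torch.randn(2, 2); B = B @ B.T + 2 * torch.eye(2)
y = torch.randn(6, 5)  # 6 = 3 * 2 rows, 5 right-hand sides

res = y.clone()
n_rows = res.size(-2)
for Q in (A, B):
    n = Q.size(-1)
    # solve against one factor, with the remaining indices folded into columns
    res = torch.linalg.solve(Q, res.reshape(n, -1))
    # rotate the factor axes so the next factor's index comes first
    res = res.reshape(n, n_rows // n, -1).permute(1, 0, 2).reshape(n_rows, -1)

# agrees with the dense solve against the full Kronecker product
dense = torch.linalg.solve(torch.kron(A, B), y)
assert torch.allclose(res, dense, atol=1e-4)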
+    def _inv_matmul(self, right_tensor, left_tensor=None):
+        # If _inv_matmul is called directly, we skip the eigenvalue handling;
+        # the solve itself is cheap because of the Kronecker structure.
+        res = self._solve(rhs=right_tensor)
         if left_tensor is not None:
             res = left_tensor @ res
         return res
 
+    def _logdet(self):
+        evals, _ = self.diagonalization()
+        logdet = evals.clamp(min=1e-7).log().sum(-1)
+        return logdet
+
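The closed form behind `_logdet` is that the eigenvalues of a Kronecker product are the pairwise products of the factor eigenvalues, so the log determinant is just the sum of their logs. A dense sketch with hypothetical SPD factors `A` and `B` (illustrative only):

import torch

A = torch.randn(3, 3); A = A @ A.T + 3 * torch.eye(3)
B = torch.randn(2, 2); B = B @ B.T + 2 * torch.eye(2)

# eigenvalues of A kron B: all products of one eigenvalue from each factor
kron_evals = torch.outer(torch.linalg.eigvalsh(A), torch.linalg.eigvalsh(B)).reshape(-1)
logdet = kron_evals.clamp(min=1e-7).log().sum()  # same clamp as _logdet above
assert torch.allclose(logdet, torch.logdet(torch.kron(A, B)), atol=1e-4)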
     def _matmul(self, rhs):
         is_vec = rhs.ndimension() == 1
         if is_vec: