@@ -150,11 +150,6 @@ def inverse(self):
150150 inverses = [lt .inverse () for lt in self .lazy_tensors ]
151151 return self .__class__ (* inverses )
152152
153- def inv_matmul (self , right_tensor , left_tensor = None ):
154- # TODO: Investigate under what conditions computing individual inverses makes sense
155- # For now, retain existing behavior
156- return super ().inv_matmul (right_tensor = right_tensor , left_tensor = left_tensor )
157-
158153 def inv_quad_logdet (self , inv_quad_rhs = None , logdet = False , reduce_inv_quad = True ):
159154 if inv_quad_rhs is not None :
160155 inv_quad_term , _ = super ().inv_quad_logdet (
@@ -217,13 +212,13 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):
217212 return res , batch_repeated_evals
218213
219214 def _inv_matmul (self , right_tensor , left_tensor = None ):
215+ # if _inv_matmul is called, we ignore the eigenvalue handling
220216 res = self ._solve (rhs = right_tensor )
221217 if left_tensor is not None :
222218 res = left_tensor @ res
223219 return res
224220
225221 def _logdet (self ):
226- # return sum([lt.logdet() * lt.shape[-1] for lt in self.lazy_tensors])
227222 evals , _ = self .diagonalization ()
228223 logdet = evals .clamp (min = 1e-7 ).log ().sum (- 1 )
229224 return logdet
0 commit comments