 import torch
 
 import linear_operator
+from linear_operator.operators import DiagLinearOperator, to_dense
 from linear_operator.settings import linalg_dtypes
 from linear_operator.utils.errors import CachingError
 from linear_operator.utils.memoize import get_from_cache
@@ -34,14 +35,16 @@ def _test_matmul(self, rhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)
         linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True)
         evaluated = self.evaluate_linear_op(linear_op_copy)
+        rhs_evaluated = to_dense(rhs)
 
         # Test operator
         res = linear_op @ rhs
-        actual = evaluated.matmul(rhs)
-        self.assertAllClose(res, actual)
+        actual = evaluated.matmul(rhs_evaluated)
+        res_evaluated = to_dense(res)
+        self.assertAllClose(res_evaluated, actual)
 
-        grad = torch.randn_like(res)
-        res.backward(gradient=grad)
+        grad = torch.randn_like(res_evaluated)
+        res_evaluated.backward(gradient=grad)
         actual.backward(gradient=grad)
         for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()):
             if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
@@ -50,7 +53,7 @@ def _test_matmul(self, rhs):
         # Test __torch_function__
         res = torch.matmul(linear_op, rhs)
         actual = evaluated.matmul(rhs)
-        self.assertAllClose(res, actual)
+        self.assertAllClose(to_dense(res), actual)
 
     def _test_rmatmul(self, lhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)
@@ -305,6 +308,12 @@ def test_rmatmul_matrix(self):
         lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2))
         return self._test_rmatmul(lhs)
 
+    def test_matmul_diag_matrix(self):
+        linear_op = self.create_linear_op()
+        diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1))
+        rhs = DiagLinearOperator(diag)
+        return self._test_matmul(rhs)
+
     def test_matmul_matrix_broadcast(self):
         linear_op = self.create_linear_op()
 
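A note on why `_test_matmul` now routes results through `to_dense`: once `rhs` can itself be a `LinearOperator` (here a `DiagLinearOperator`), `linear_op @ rhs` may return another lazy `LinearOperator` rather than a `torch.Tensor`, so tensor-only calls such as `backward()`, `torch.randn_like()`, and `assertAllClose()` need a materialized tensor first. A minimal sketch of that behavior (assuming `DenseLinearOperator` from `linear_operator.operators` wraps a plain tensor, as in recent releases; not part of this diff):

```python
import torch
from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense

# A small dense operator and a diagonal right-hand side.
lhs = DenseLinearOperator(torch.randn(3, 3))
rhs = DiagLinearOperator(torch.rand(3))

# Multiplying two LinearOperators stays lazy: the result is another
# LinearOperator, not a torch.Tensor.
res = lhs @ rhs
print(isinstance(res, torch.Tensor))  # False: a LinearOperator subclass

# to_dense materializes the product so tensor-only APIs
# (backward, assertAllClose, randn_like, ...) can be applied to it.
dense = to_dense(res)
print(dense.shape)  # torch.Size([3, 3])
```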