@@ -58,19 +58,6 @@ def _test_matmul(self, rhs):
58 58         actual = evaluated.matmul(rhs)
59 59         self.assertAllClose(to_dense(res), actual)
60 60
61    -     def _test_t_matmul(self, rhs):
62    -         with torch.no_grad():
63    -             linear_op = self.create_linear_op()
64    -             linear_op_copy = torch.clone(linear_op)
65    -             evaluated = self.evaluate_linear_op(linear_op_copy)
66    -             rhs_evaluated = to_dense(rhs)
67    -
68    -             # Test operator
69    -             res = linear_op._t_matmul(rhs)
70    -             actual = evaluated.mT.matmul(rhs_evaluated)
71    -             res_evaluated = to_dense(res)
72    -             self.assertAllClose(res_evaluated, actual)
73    -
74 61     def _test_rmatmul(self, lhs):
75 62         linear_op = self.create_linear_op().detach().requires_grad_(True)
76 63         linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True)
@@ -405,9 +392,18 @@ def test_matmul_matrix(self):
405 392         return self._test_matmul(rhs)
406 393
407 394     def test_t_matmul_matrix(self):
408     -         linear_op = self.create_linear_op()
409     -         rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
410     -         return self._test_t_matmul(rhs)
    395 +         with torch.no_grad():
    396 +             linear_op = self.create_linear_op()
    397 +             rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
    398 +             linear_op_copy = torch.clone(linear_op)
    399 +             evaluated = self.evaluate_linear_op(linear_op_copy)
    400 +             rhs_evaluated = to_dense(rhs)
    401 +
    402 +             # Test operator
    403 +             res = linear_op._t_matmul(rhs)
    404 +             actual = evaluated.mT.matmul(rhs_evaluated)
    405 +             res_evaluated = to_dense(res)
    406 +             self.assertAllClose(res_evaluated, actual)
411 407
412 408     def test_rmatmul_matrix(self):
413 409         linear_op = self.create_linear_op()
0 commit comments