Skip to content

Commit 8c48538

Browse files
committed
Simplify test
1 parent e89fcaf commit 8c48538

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

linear_operator/test/linear_operator_test_case.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,6 @@ def _test_matmul(self, rhs):
5858
actual = evaluated.matmul(rhs)
5959
self.assertAllClose(to_dense(res), actual)
6060

61-
def _test_t_matmul(self, rhs):
62-
with torch.no_grad():
63-
linear_op = self.create_linear_op()
64-
linear_op_copy = torch.clone(linear_op)
65-
evaluated = self.evaluate_linear_op(linear_op_copy)
66-
rhs_evaluated = to_dense(rhs)
67-
68-
# Test operator
69-
res = linear_op._t_matmul(rhs)
70-
actual = evaluated.mT.matmul(rhs_evaluated)
71-
res_evaluated = to_dense(res)
72-
self.assertAllClose(res_evaluated, actual)
73-
7461
def _test_rmatmul(self, lhs):
7562
linear_op = self.create_linear_op().detach().requires_grad_(True)
7663
linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True)
@@ -405,9 +392,18 @@ def test_matmul_matrix(self):
405392
return self._test_matmul(rhs)
406393

407394
def test_t_matmul_matrix(self):
408-
linear_op = self.create_linear_op()
409-
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
410-
return self._test_t_matmul(rhs)
395+
with torch.no_grad():
396+
linear_op = self.create_linear_op()
397+
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
398+
linear_op_copy = torch.clone(linear_op)
399+
evaluated = self.evaluate_linear_op(linear_op_copy)
400+
rhs_evaluated = to_dense(rhs)
401+
402+
# Test operator
403+
res = linear_op._t_matmul(rhs)
404+
actual = evaluated.mT.matmul(rhs_evaluated)
405+
res_evaluated = to_dense(res)
406+
self.assertAllClose(res_evaluated, actual)
411407

412408
def test_rmatmul_matrix(self):
413409
linear_op = self.create_linear_op()

0 commit comments

Comments
 (0)