Skip to content

Commit 6ea8866

Browse files
authored
Merge pull request #72 from Turakar/fix_masked_t_matmul
Fix _t_matmul() in MaskedLinearOperator and add test
2 parents e915cc0 + 8c48538 commit 6ea8866

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

linear_operator/operators/masked_linear_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _matmul(
5050
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
5151
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
5252
rhs_expanded = self._expand(rhs, self.col_mask)
53-
res_expanded = self.base.matmul(rhs_expanded)
53+
res_expanded = self.base._matmul(rhs_expanded)
5454
res = res_expanded[..., self.row_mask, :]
5555

5656
return res
@@ -60,7 +60,7 @@ def _t_matmul(
6060
rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
6161
) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
6262
rhs_expanded = self._expand(rhs, self.row_mask)
63-
res_expanded = self.base.t_matmul(rhs_expanded)
63+
res_expanded = self.base._t_matmul(rhs_expanded)
6464
res = res_expanded[..., self.col_mask, :]
6565
return res
6666

linear_operator/test/linear_operator_test_case.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,20 @@ def test_matmul_matrix(self):
391391
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4)
392392
return self._test_matmul(rhs)
393393

394+
def test_t_matmul_matrix(self):
395+
with torch.no_grad():
396+
linear_op = self.create_linear_op()
397+
rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4)
398+
linear_op_copy = torch.clone(linear_op)
399+
evaluated = self.evaluate_linear_op(linear_op_copy)
400+
rhs_evaluated = to_dense(rhs)
401+
402+
# Test operator
403+
res = linear_op._t_matmul(rhs)
404+
actual = evaluated.mT.matmul(rhs_evaluated)
405+
res_evaluated = to_dense(res)
406+
self.assertAllClose(res_evaluated, actual)
407+
394408
def test_rmatmul_matrix(self):
395409
linear_op = self.create_linear_op()
396410
lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2))

0 commit comments

Comments
 (0)