Skip to content

Commit 2101877

Browse files
authored
Merge pull request #119 from kayween/fix-block-diag-linear-operator-diagonal
Convert `base_linear_op` to a dense linear operator in `BlockDiagLinearOperator`
2 parents dbed373 + 86092eb commit 2101877

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

linear_operator/operators/block_diag_linear_operator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from linear_operator.operators._linear_operator import IndexType, LinearOperator
1111
from linear_operator.operators.block_linear_operator import BlockLinearOperator
12+
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
1213

1314
from linear_operator.utils.memoize import cached
1415

@@ -49,6 +50,9 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea
4950
"""
5051

5152
def __init__(self, base_linear_op, block_dim=-3):
53+
if isinstance(base_linear_op, Tensor):
54+
base_linear_op = DenseLinearOperator(base_linear_op)
55+
5256
super().__init__(base_linear_op, block_dim)
5357
# block diagonal is restricted to have square diagonal blocks
5458
if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]:

test/operators/test_block_diag_linear_operator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,19 @@ class TestBlockDiagLinearOperator(LinearOperatorTestCase, unittest.TestCase):
1212
seed = 0
1313
should_test_sample = True
1414

15+
# Whether to initialize `BlockDiagLinearOperator` from a tensor or a linear operator.
16+
_initialize_from_tensor = False
17+
1518
def create_linear_op(self):
1619
blocks = torch.randn(8, 4, 4)
1720
blocks = blocks.matmul(blocks.mT)
1821
blocks.add_(torch.eye(4, 4).unsqueeze_(0))
19-
return BlockDiagLinearOperator(DenseLinearOperator(blocks))
22+
23+
return (
24+
BlockDiagLinearOperator(blocks)
25+
if self._initialize_from_tensor
26+
else BlockDiagLinearOperator(DenseLinearOperator(blocks))
27+
)
2028

2129
def evaluate_linear_op(self, linear_op):
2230
blocks = linear_op.base_linear_op.tensor
@@ -26,6 +34,10 @@ def evaluate_linear_op(self, linear_op):
2634
return actual
2735

2836

37+
class TestBlockDiagLinearOperatorFromTensor(TestBlockDiagLinearOperator):
38+
_initialize_from_tensor = True
39+
40+
2941
class TestBlockDiagLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
3042
seed = 0
3143
should_test_sample = True
@@ -75,7 +87,7 @@ def test_metaclass_constructor(self):
7587
base_operators = [torch.randn(k, n), torch.randn(b1, b2, k, n)]
7688
subtest_names = ["non-batched input", "batched input"]
7789
# repeats tests for both batched and non-batched tensors
78-
for (base_op, test_name) in zip(base_operators, subtest_names):
90+
for base_op, test_name in zip(base_operators, subtest_names):
7991
with self.subTest(test_name):
8092
base_diag = DiagLinearOperator(base_op)
8193
linear_op = BlockDiagLinearOperator(base_diag)

0 commit comments

Comments
 (0)