@@ -12,11 +12,19 @@ class TestBlockDiagLinearOperator(LinearOperatorTestCase, unittest.TestCase):
1212 seed = 0
1313 should_test_sample = True
1414
15+ # Whether to initialize `BlockDiagLinearOperator` from a tensor or a linear operator.
16+ _initialize_from_tensor = False
17+
1518 def create_linear_op (self ):
1619 blocks = torch .randn (8 , 4 , 4 )
1720 blocks = blocks .matmul (blocks .mT )
1821 blocks .add_ (torch .eye (4 , 4 ).unsqueeze_ (0 ))
19- return BlockDiagLinearOperator (DenseLinearOperator (blocks ))
22+
23+ return (
24+ BlockDiagLinearOperator (blocks )
25+ if self ._initialize_from_tensor
26+ else BlockDiagLinearOperator (DenseLinearOperator (blocks ))
27+ )
2028
2129 def evaluate_linear_op (self , linear_op ):
2230 blocks = linear_op .base_linear_op .tensor
@@ -26,6 +34,10 @@ def evaluate_linear_op(self, linear_op):
2634 return actual
2735
2836
37+ class TestBlockDiagLinearOperatorFromTensor (TestBlockDiagLinearOperator ):
38+ _initialize_from_tensor = True
39+
40+
2941class TestBlockDiagLinearOperatorBatch (LinearOperatorTestCase , unittest .TestCase ):
3042 seed = 0
3143 should_test_sample = True
@@ -75,7 +87,7 @@ def test_metaclass_constructor(self):
7587 base_operators = [torch .randn (k , n ), torch .randn (b1 , b2 , k , n )]
7688 subtest_names = ["non-batched input" , "batched input" ]
7789 # repeats tests for both batched and non-batched tensors
78- for ( base_op , test_name ) in zip (base_operators , subtest_names ):
90+ for base_op , test_name in zip (base_operators , subtest_names ):
7991 with self .subTest (test_name ):
8092 base_diag = DiagLinearOperator (base_op )
8193 linear_op = BlockDiagLinearOperator (base_diag )
0 commit comments