|
17 | 17 | blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
|
18 | 18 | blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
|
19 | 19 |
|
20 |
| - @testset "$T" for (T, (b1, b2, b3)) in ( |
21 |
| - Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))), |
22 |
| - Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)), |
23 |
| - ) |
24 |
| - A = rand(rng, N, N + N1) |
25 |
| - B = rand(rng, N + N1, N + N2) |
26 |
| - A′, B′ = A', B' |
27 |
| - a = rand(rng, N) |
28 |
| - b = rand(rng, N + N1) |
| 20 | + @testset for V in (Tuple, Vector) |
| 21 | + b1 = BlockDiagonal(V(blocks1)) |
| 22 | + b2 = BlockDiagonal(V(blocks2)) |
| 23 | + N = size(b1, 1) |
29 | 24 |
|
30 | 25 | @testset "AbstractArray" begin
|
31 |
| - X = rand(2, 2); Y = rand(3, 3) |
| 26 | + X = rand(2, 2) |
| 27 | + Y = rand(3, 3) |
32 | 28 |
|
33 | 29 | @test size(b1) == (N, N)
|
34 | 30 | @test size(b1, 1) == N && size(b1, 2) == N
|
|
53 | 49 | end
|
54 | 50 |
|
55 | 51 | @testset "parent" begin
|
56 |
| - @test parent(b1) isa Union{Tuple,AbstractVector} |
| 52 | + @test parent(b1) isa V |
57 | 53 | @test eltype(parent(b1)) <: AbstractMatrix
|
58 | 54 | @test parent(BlockDiagonal([X, Y])) == [X, Y]
|
59 | 55 | @test parent(BlockDiagonal((X, Y))) == (X, Y)
|
|
66 | 62 | end
|
67 | 63 |
|
68 | 64 | @testset "setindex!" begin
|
69 |
| - X = BlockDiagonal([rand(Float32, 5, 5), rand(Float32, 3, 3)]) |
| 65 | + X = BlockDiagonal(V([rand(Float32, 5, 5), rand(Float32, 3, 3)])) |
70 | 66 | X[10] = Int(10)
|
71 | 67 | @test X[10] === Float32(10.0)
|
72 | 68 | X[3, 3] = Int(9)
|
|
78 | 74 |
|
79 | 75 | @testset "ChainRules" begin
|
80 | 76 | @testset "BlockDiagonal" begin
|
81 |
| - x = [randn(1, 2), randn(2, 2)] |
82 |
| - x̄ = [randn(1, 2), randn(2, 2)] |
83 |
| - ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)]) |
| 77 | + x = V([randn(1, 2), randn(2, 2)]) |
| 78 | + x̄ = V([randn(1, 2), randn(2, 2)]) |
| 79 | + |
| 80 | + ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=V([randn(1, 2), randn(2, 2)])) |
84 | 81 | rrule_test(BlockDiagonal, ȳ, (x, x̄))
|
85 | 82 | end
|
86 | 83 | @testset "Matrix" begin
|
87 |
| - D = BlockDiagonal([randn(1, 2), randn(2, 2)]) |
88 |
| - D̄ = Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), ) |
| 84 | + D = BlockDiagonal(V([randn(1, 2), randn(2, 2)])) |
| 85 | + D̄ = Composite{typeof(D)}((blocks=V([randn(1, 2), randn(2, 2)])),) |
89 | 86 | Ȳ = randn(size(D))
|
90 | 87 | rrule_test(Matrix, Ȳ, (D, D̄))
|
91 | 88 | end
|
|
98 | 95 | end
|
99 | 96 |
|
100 | 97 | @testset "blocks size" begin
|
101 |
| - B = BlockDiagonal([rand(3, 3), rand(4, 4)]) |
| 98 | + B = BlockDiagonal(V([rand(3, 3), rand(4, 4)])) |
102 | 99 | @test nblocks(B) == 2
|
103 |
| - @test blocksizes(B) == [(3, 3), (4, 4)] |
| 100 | + @test blocksizes(B) == V([(3, 3), (4, 4)]) |
104 | 101 | @test blocksize(B, 2) == blocksizes(B)[2] == blocksize(B, 2, 2)
|
105 | 102 | end
|
106 | 103 |
|
|
124 | 121 | @testset "Non-Square Matrix" begin
|
125 | 122 | A1 = ones(2, 4)
|
126 | 123 | A2 = 2 * ones(3, 2)
|
127 |
| - B1 = BlockDiagonal([A1, A2]) |
128 |
| - B2 = [A1 zeros(2, 2); zeros(3, 4) A2] |
| 124 | + B1 = BlockDiagonal(V([A1, A2])) |
| 125 | + B2 = [A1 zeros(2, 2); zeros(3, 4) A2] |
129 | 126 |
|
130 | 127 | @test B1 == B2
|
131 | 128 | # Dimension check
|
|
0 commit comments