Skip to content

Commit 5477835

Browse files
Improve tests in blockdiagonal.jl
1 parent 9bf2360 commit 5477835

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

test/blockdiagonal.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,14 @@ end
1717
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
1818
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
1919

20-
@testset "$T" for (T, (b1, b2, b3)) in (
21-
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
22-
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
23-
)
24-
A = rand(rng, N, N + N1)
25-
B = rand(rng, N + N1, N + N2)
26-
A′, B′ = A', B'
27-
a = rand(rng, N)
28-
b = rand(rng, N + N1)
20+
@testset for V in (Tuple, Vector)
21+
b1 = BlockDiagonal(V(blocks1))
22+
b2 = BlockDiagonal(V(blocks2))
23+
N = size(b1, 1)
2924

3025
@testset "AbstractArray" begin
31-
X = rand(2, 2); Y = rand(3, 3)
26+
X = rand(2, 2)
27+
Y = rand(3, 3)
3228

3329
@test size(b1) == (N, N)
3430
@test size(b1, 1) == N && size(b1, 2) == N
@@ -53,7 +49,7 @@ end
5349
end
5450

5551
@testset "parent" begin
56-
@test parent(b1) isa Union{Tuple,AbstractVector}
52+
@test parent(b1) isa V
5753
@test eltype(parent(b1)) <: AbstractMatrix
5854
@test parent(BlockDiagonal([X, Y])) == [X, Y]
5955
@test parent(BlockDiagonal((X, Y))) == (X, Y)
@@ -66,7 +62,7 @@ end
6662
end
6763

6864
@testset "setindex!" begin
69-
X = BlockDiagonal([rand(Float32, 5, 5), rand(Float32, 3, 3)])
65+
X = BlockDiagonal(V([rand(Float32, 5, 5), rand(Float32, 3, 3)]))
7066
X[10] = Int(10)
7167
@test X[10] === Float32(10.0)
7268
X[3, 3] = Int(9)
@@ -78,14 +74,15 @@ end
7874

7975
@testset "ChainRules" begin
8076
@testset "BlockDiagonal" begin
81-
x = [randn(1, 2), randn(2, 2)]
82-
= [randn(1, 2), randn(2, 2)]
83-
= Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)])
77+
x = V([randn(1, 2), randn(2, 2)])
78+
= V([randn(1, 2), randn(2, 2)])
79+
80+
= Composite{typeof(BlockDiagonal(x))}(blocks=V([randn(1, 2), randn(2, 2)]))
8481
rrule_test(BlockDiagonal, ȳ, (x, x̄))
8582
end
8683
@testset "Matrix" begin
87-
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
88-
= Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), )
84+
D = BlockDiagonal(V([randn(1, 2), randn(2, 2)]))
85+
= Composite{typeof(D)}((blocks=V([randn(1, 2), randn(2, 2)])),)
8986
= randn(size(D))
9087
rrule_test(Matrix, Ȳ, (D, D̄))
9188
end
@@ -98,9 +95,9 @@ end
9895
end
9996

10097
@testset "blocks size" begin
101-
B = BlockDiagonal([rand(3, 3), rand(4, 4)])
98+
B = BlockDiagonal(V([rand(3, 3), rand(4, 4)]))
10299
@test nblocks(B) == 2
103-
@test blocksizes(B) == [(3, 3), (4, 4)]
100+
@test blocksizes(B) == V([(3, 3), (4, 4)])
104101
@test blocksize(B, 2) == blocksizes(B)[2] == blocksize(B, 2, 2)
105102
end
106103

@@ -124,8 +121,8 @@ end
124121
@testset "Non-Square Matrix" begin
125122
A1 = ones(2, 4)
126123
A2 = 2 * ones(3, 2)
127-
B1 = BlockDiagonal([A1, A2])
128-
B2 = [A1 zeros(2, 2); zeros(3, 4) A2]
124+
B1 = BlockDiagonal(V([A1, A2]))
125+
B2 = [A1 zeros(2, 2); zeros(3, 4) A2]
129126

130127
@test B1 == B2
131128
# Dimension check

0 commit comments

Comments
 (0)