Skip to content

Commit 32f9d07

Browse files
Simplify test setup for base_maths.jl
1 parent ac3ddc2 commit 32f9d07

File tree

1 file changed

+24
-27
lines changed

1 file changed

+24
-27
lines changed

test/base_maths.jl

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,31 @@ using Test
55

66
@testset "base_maths.jl" begin
77
rng = MersenneTwister(123456)
8-
N1, N2, N3 = 3, 4, 5
9-
N = N1 + N2 + N3
10-
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
11-
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
12-
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
13-
14-
@testset "$T" for (T, (b1, b2, b3)) in (
15-
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
16-
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
17-
)
18-
A = rand(rng, N, N + N1)
19-
B = rand(rng, N + N1, N + N2)
20-
A′, B′ = A', B'
21-
a = rand(rng, N)
22-
b = rand(rng, N + N1)
8+
blocks1 = [rand(rng, 3, 3), rand(rng, 4, 4)]
9+
blocks2 = [rand(rng, 3, 3), rand(rng, 5, 5)]
10+
11+
@testset for V in (Tuple, Vector)
12+
b1 = BlockDiagonal(V(blocks1))
13+
b2 = BlockDiagonal(V(blocks2))
14+
N = size(b1, 1)
15+
A = rand(rng, N, N + 1)
2316

2417
@testset "Addition" begin
2518
@testset "BlockDiagonal + BlockDiagonal" begin
2619
@test b1 + b1 isa BlockDiagonal
2720
@test Matrix(b1 + b1) == Matrix(b1) + Matrix(b1)
28-
@test_throws DimensionMismatch b1 + b3
21+
@test_throws DimensionMismatch b1 + b2
2922
end
3023

3124
@testset "BlockDiagonal + Matrix" begin
3225
@test b1 + Matrix(b1) isa Matrix
3326
@test b1 + Matrix(b1) == b1 + b1
34-
@test_throws DimensionMismatch b1 + Matrix(b3)
27+
@test_throws DimensionMismatch b1 + Matrix(b2)
3528

3629
# Matrix + BlockDiagonal
3730
@test Matrix(b1) + b1 isa Matrix
3831
@test Matrix(b1) + b1 == b1 + b1
39-
@test_throws DimensionMismatch Matrix(b1) + b3
32+
@test_throws DimensionMismatch Matrix(b1) + b2
4033

4134
# If the AbstractMatrix is diagonal, we should return a BlockDiagonal.
4235
# Test the StridedMatrix method.
@@ -50,7 +43,7 @@ using Test
5043

5144
@testset "BlockDiagonal + Diagonal" begin
5245
D = Diagonal(randn(rng, N))
53-
D′ = Diagonal(randn(rng, N + N1))
46+
D′ = Diagonal(randn(rng, N + 1))
5447

5548
@test b1 + D isa BlockDiagonal
5649
@test b1 + D == Matrix(b1) + D
@@ -73,11 +66,10 @@ using Test
7366
end # Addition
7467

7568
@testset "Multiplication" begin
76-
7769
@testset "BlockDiagonal * BlockDiagonal" begin
7870
@test b1 * b1 isa BlockDiagonal
7971
@test Matrix(b1 * b1) Matrix(b1) * Matrix(b1)
80-
@test_throws DimensionMismatch b3 * b1
72+
@test_throws DimensionMismatch b2 * b1
8173
end
8274

8375
@testset "BlockDiagonal * Number" begin
@@ -88,11 +80,14 @@ using Test
8880
end
8981

9082
@testset "BlockDiagonal * Vector" begin
83+
a = rand(rng, N)
9184
@test b1 * a isa Vector
9285
@test b1 * a Matrix(b1) * a
86+
b = rand(rng, N + 1)
9387
@test_throws DimensionMismatch b1 * b
9488
end
9589
@testset "Vector^T * BlockDiagonal" begin
90+
a = rand(rng, N)
9691
@test a' * b1 isa Adjoint{<:Number, <:Vector}
9792
@test transpose(a) * b1 isa Transpose{<:Number, <:Vector}
9893
@test a' * b1 a' * Matrix(b1)
@@ -102,11 +97,13 @@ using Test
10297
@testset "BlockDiagonal * Matrix" begin
10398
@test b1 * A isa Matrix
10499
@test b1 * A Matrix(b1) * A
100+
101+
B = rand(rng, N + 1, N)
105102
@test_throws DimensionMismatch b1 * B
106103

107104
# Matrix * BlockDiagonal
108-
@test A * b1 isa Matrix
109-
@test A * b1 A * Matrix(b1)
105+
@test A' * b1 isa Matrix
106+
@test A' * b1 A' * Matrix(b1)
110107
@test_throws DimensionMismatch A * b1
111108

112109
# degenerate cases
@@ -119,7 +116,7 @@ using Test
119116

120117
@testset "BlockDiagonal * Diagonal" begin
121118
D = Diagonal(randn(rng, N))
122-
D′ = Diagonal(randn(rng, N + N1))
119+
D′ = Diagonal(randn(rng, N + 1))
123120

124121
@test b1 * D isa BlockDiagonal
125122
@test b1 * D Matrix(b1) * D
@@ -132,8 +129,8 @@ using Test
132129
end
133130

134131
@testset "Non-Square BlockDiagonal * Non-Square BlockDiagonal" begin
135-
b4 = BlockDiagonal(T([ones(2, 4), 2 * ones(3, 2)]))
136-
b5 = BlockDiagonal(T([3 * ones(2, 2), 2 * ones(4, 1)]))
132+
b4 = BlockDiagonal(V([ones(2, 4), 2 * ones(3, 2)]))
133+
b5 = BlockDiagonal(V([3 * ones(2, 2), 2 * ones(4, 1)]))
137134

138135
@test b4 * b5 isa Array
139136
@test b4 * b5 == [6 * ones(2, 2) 4 * ones(2, 1); zeros(3, 2) 8 * ones(3, 1)]
@@ -142,5 +139,5 @@ using Test
142139
@test sum(size.(b5.blocks, 2)) == size(b4 * b5, 2)
143140
end
144141
end # Multiplication
145-
end
142+
end # V
146143
end

0 commit comments

Comments
 (0)