Skip to content

Commit 17ccf8e

Browse files
committed
Add tests
1 parent 945c660 commit 17ccf8e

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

test/test_factorizations.jl

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
3-
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank
3+
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
44
using LinearAlgebra: LinearAlgebra
55
using Random: Random
66
using Test: @inferred, @testset, @test
@@ -88,7 +88,6 @@ end
8888
# ----------
8989

9090
@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
91-
(m, n), T = first(test_params)
9291
a = BlockSparseArray{T}(undef, m, n)
9392

9493
# test blockdiagonal
@@ -99,9 +98,10 @@ end
9998

10099
minmn = min(size(a)...)
101100
r = max(1, minmn - 2)
101+
trunc = truncrank(r)
102102

103-
U1, S1, V1ᴴ = svd_trunc(a; trunc=truncrank(r))
104-
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc=truncrank(r))
103+
U1, S1, V1ᴴ = svd_trunc(a; trunc)
104+
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
105105
@test size(U1) == size(U2)
106106
@test size(S1) == size(S2)
107107
@test size(V1ᴴ) == size(V2ᴴ)
@@ -110,11 +110,11 @@ end
110110
@test (U1' * U1 LinearAlgebra.I)
111111
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
112112

113-
# test permuted blockdiagonal
114-
perm = Random.randperm(length(m))
115-
b = a[Block.(perm), Block.(1:length(n))]
116-
U1, S1, V1ᴴ = svd_trunc(b; trunc=truncrank(r))
117-
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc=truncrank(r))
113+
atol = minimum(LinearAlgebra.diag(S1)) + 10 * eps(real(T))
114+
trunc = trunctol(atol)
115+
116+
U1, S1, V1ᴴ = svd_trunc(a; trunc)
117+
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
118118
@test size(U1) == size(U2)
119119
@test size(S1) == size(S2)
120120
@test size(V1ᴴ) == size(V2ᴴ)
@@ -123,17 +123,34 @@ end
123123
@test (U1' * U1 LinearAlgebra.I)
124124
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
125125

126+
# test permuted blockdiagonal
127+
perm = Random.randperm(length(m))
128+
b = a[Block.(perm), Block.(1:length(n))]
129+
for trunc in (truncrank(r), trunctol(atol))
130+
U1, S1, V1ᴴ = svd_trunc(b; trunc)
131+
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc)
132+
@test size(U1) == size(U2)
133+
@test size(S1) == size(S2)
134+
@test size(V1ᴴ) == size(V2ᴴ)
135+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
136+
137+
@test (U1' * U1 LinearAlgebra.I)
138+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
139+
end
140+
126141
# test permuted blockdiagonal with missing row/col
127142
I_removed = rand(eachblockstoredindex(b))
128143
c = copy(b)
129144
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
130-
U1, S1, V1ᴴ = svd_trunc(c; trunc=truncrank(r))
131-
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc=truncrank(r))
132-
@test size(U1) == size(U2)
133-
@test size(S1) == size(S2)
134-
@test size(V1ᴴ) == size(V2ᴴ)
135-
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
136-
137-
@test (U1' * U1 LinearAlgebra.I)
138-
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
145+
for trunc in (truncrank(r), trunctol(atol))
146+
U1, S1, V1ᴴ = svd_trunc(c; trunc)
147+
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc)
148+
@test size(U1) == size(U2)
149+
@test size(S1) == size(S2)
150+
@test size(V1ᴴ) == size(V2ᴴ)
151+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
152+
153+
@test (U1' * U1 LinearAlgebra.I)
154+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
155+
end
139156
end

0 commit comments

Comments
 (0)