Skip to content

Commit 5146831

Browse files
committed
More tests, small fixes
1 parent 81062b5 commit 5146831

File tree

5 files changed

+157
-33
lines changed

5 files changed

+157
-33
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,11 @@ end
350350

351351
struct BlockType{T} end
352352
BlockType(x) = BlockType{x}()
353-
function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}) where {T}
354-
return BlockSparseArray{eltype(T),ndims(T),T}(undef, axes(a))
353+
function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}, ax) where {T}
354+
return BlockSparseArray{eltype(T),ndims(T),T}(undef, ax)
355+
end
356+
function Base.similar(a::AbstractBlockSparseArray, T::BlockType)
357+
return similar(a, T, axes(a))
355358
end
356359

357360
# TODO: Implement this in a more generic way using a smarter `copyto!`,

src/factorizations/eig.jl

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ for f in [:eig_vals!, :eigh_vals!]
5555
function MatrixAlgebraKit.initialize_output(
5656
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
5757
)
58-
return similar(A, axes(A, 1))
58+
T = Base.promote_op($f, blocktype(A), typeof(alg.alg))
59+
return similar(A, BlockType(T), axes(A, 1))
5960
end
6061
function MatrixAlgebraKit.$f(
6162
A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm
@@ -67,24 +68,3 @@ for f in [:eig_vals!, :eigh_vals!]
6768
end
6869
end
6970
end
70-
71-
const TBlockDV = Tuple{AbstractBlockSparseMatrix,AbstractBlockSparseMatrix}
72-
73-
for f in [:eig_trunc!, :eigh_trunc!]
74-
@eval begin
75-
function MatrixAlgebraKit.truncate!(
76-
::typeof($f), (D, V)::TBlockDV, strategy::TruncationStrategy
77-
)
78-
return MatrixAlgebraKit.truncate!(
79-
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
80-
)
81-
end
82-
function MatrixAlgebraKit.truncate!(
83-
::typeof($f), (D, V)::TBlockDV, strategy::BlockPermutedDiagonalTruncationStrategy
84-
)
85-
d = diagview(D)
86-
ind = findtruncated(d, strategy)
87-
return diagonal(d[ind]), V[:, ind]
88-
end
89-
end
90-
end

src/factorizations/truncation.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!
1+
using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
22

33
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
44
D = BlockSparseVector{T}(undef, axes(A, 1))
@@ -21,18 +21,29 @@ struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: Truncat
2121
strategy::T
2222
end
2323

24-
const TBlockUSVᴴ = Tuple{
25-
AbstractBlockSparseMatrix,AbstractBlockSparseMatrix,AbstractBlockSparseMatrix
26-
}
27-
2824
function MatrixAlgebraKit.truncate!(
29-
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
25+
::typeof(svd_trunc!),
26+
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
27+
strategy::TruncationStrategy,
3028
)
3129
# TODO assert blockdiagonal
3230
return MatrixAlgebraKit.truncate!(
3331
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
3432
)
3533
end
34+
for f in [:eig_trunc!, :eigh_trunc!]
35+
@eval begin
36+
function MatrixAlgebraKit.truncate!(
37+
::typeof($f),
38+
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
39+
strategy::TruncationStrategy,
40+
)
41+
return MatrixAlgebraKit.truncate!(
42+
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
43+
)
44+
end
45+
end
46+
end
3647

3748
# cannot use regular slicing here: I want to slice without altering blockstructure
3849
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
@@ -47,9 +58,21 @@ end
4758

4859
function MatrixAlgebraKit.truncate!(
4960
::typeof(svd_trunc!),
50-
(U, S, Vᴴ)::TBlockUSV,
61+
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
5162
strategy::BlockPermutedDiagonalTruncationStrategy,
5263
)
5364
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
5465
return (U[:, I], S[I, I], Vᴴ[I, :])
5566
end
67+
for f in [:eig_trunc!, :eigh_trunc!]
68+
@eval begin
69+
function MatrixAlgebraKit.truncate!(
70+
::typeof($f),
71+
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
72+
strategy::BlockPermutedDiagonalTruncationStrategy,
73+
)
74+
I = MatrixAlgebraKit.findtruncated(diagview(D), strategy)
75+
return (D[I, I], V[:, I])
76+
end
77+
end
78+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1414
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
15+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1516
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1617
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/test_factorizations.jl

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
2-
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
2+
using BlockSparseArrays:
3+
BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex
34
using MatrixAlgebraKit:
5+
diagview,
6+
eig_full,
7+
eig_trunc,
8+
eig_vals,
9+
eigh_full,
10+
eigh_trunc,
11+
eigh_vals,
412
left_orth,
513
left_polar,
614
lq_compact,
@@ -14,8 +22,9 @@ using MatrixAlgebraKit:
1422
svd_trunc,
1523
truncrank,
1624
trunctol
17-
using LinearAlgebra: LinearAlgebra
25+
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart
1826
using Random: Random
27+
using StableRNGs: StableRNG
1928
using Test: @inferred, @testset, @test
2029

2130
function test_svd(a, (U, S, Vᴴ); full=false)
@@ -273,3 +282,111 @@ end
273282
@test size(U, 1) 2
274283
@test Matrix(U * U') LinearAlgebra.I
275284
end
285+
286+
@testset "eig_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
287+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
288+
rng = StableRNG(123)
289+
A[Block(1, 1)] = randn(rng, T, 2, 2)
290+
A[Block(2, 2)] = randn(rng, T, 3, 3)
291+
292+
D, V = eig_full(A)
293+
@test size(D) == size(A)
294+
@test size(D) == size(A)
295+
@test blockstoredlength(D) == 2
296+
@test blockstoredlength(V) == 2
297+
@test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)])
298+
@test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)])
299+
@test A * V V * D
300+
end
301+
302+
@testset "eig_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
303+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
304+
rng = StableRNG(123)
305+
A[Block(1, 1)] = randn(rng, T, 2, 2)
306+
A[Block(2, 2)] = randn(rng, T, 3, 3)
307+
308+
D = eig_vals(A)
309+
@test size(D) == (size(A, 1),)
310+
@test blockstoredlength(D) == 2
311+
D′ = eig_vals(Matrix(A))
312+
@test sort(D; by=abs) sort(D′; by=abs)
313+
end
314+
315+
@testset "eig_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
316+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
317+
rng = StableRNG(123)
318+
D1 = [1.0, 0.1]
319+
V1 = randn(rng, T, 2, 2)
320+
A1 = V1 * Diagonal(D1) * inv(V1)
321+
D2 = [1.0, 0.5, 0.1]
322+
V2 = randn(rng, T, 3, 3)
323+
A2 = V2 * Diagonal(D2) * inv(V2)
324+
A[Block(1, 1)] = A1
325+
A[Block(2, 2)] = A2
326+
327+
D, V = eig_trunc(A; trunc=(; maxrank=3))
328+
@test size(D) == (3, 3)
329+
@test size(D) == (3, 3)
330+
@test blockstoredlength(D) == 2
331+
@test blockstoredlength(V) == 2
332+
@test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)])
333+
@test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)])
334+
@test A * V V * D
335+
@test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) D1[1:1]
336+
@test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) D2[1:2]
337+
end
338+
339+
herm(x) = parent(hermitianpart(x))
340+
341+
@testset "eigh_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
342+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
343+
rng = StableRNG(123)
344+
A[Block(1, 1)] = herm(randn(rng, T, 2, 2))
345+
A[Block(2, 2)] = herm(randn(rng, T, 3, 3))
346+
347+
D, V = eigh_full(A)
348+
@test size(D) == size(A)
349+
@test size(D) == size(A)
350+
@test blockstoredlength(D) == 2
351+
@test blockstoredlength(V) == 2
352+
@test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)])
353+
@test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)])
354+
@test A * V V * D
355+
end
356+
357+
@testset "eigh_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
358+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
359+
rng = StableRNG(123)
360+
A[Block(1, 1)] = herm(randn(rng, T, 2, 2))
361+
A[Block(2, 2)] = herm(randn(rng, T, 3, 3))
362+
363+
D = eigh_vals(A)
364+
@test size(D) == (size(A, 1),)
365+
@test blockstoredlength(D) == 2
366+
D′ = eigh_vals(Matrix(A))
367+
@test sort(D; by=abs) sort(D′; by=abs)
368+
end
369+
370+
@testset "eigh_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
371+
A = BlockSparseArray{T}(undef, ([2, 3], [2, 3]))
372+
rng = StableRNG(123)
373+
D1 = [1.0, 0.1]
374+
V1, _ = qr_compact(randn(rng, T, 2, 2))
375+
A1 = V1 * Diagonal(D1) * V1'
376+
D2 = [1.0, 0.5, 0.1]
377+
V2, _ = qr_compact(randn(rng, T, 3, 3))
378+
A2 = V2 * Diagonal(D2) * V2'
379+
A[Block(1, 1)] = herm(A1)
380+
A[Block(2, 2)] = herm(A2)
381+
382+
D, V = eigh_trunc(A; trunc=(; maxrank=3))
383+
@test size(D) == (3, 3)
384+
@test size(D) == (3, 3)
385+
@test blockstoredlength(D) == 2
386+
@test blockstoredlength(V) == 2
387+
@test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)])
388+
@test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)])
389+
@test A * V V * D
390+
@test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) D1[1:1]
391+
@test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) D2[1:2]
392+
end

0 commit comments

Comments
 (0)