diff --git a/Project.toml b/Project.toml index 852f72ae..8922533a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.1" +version = "0.7.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 6470fc61..ec996e06 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -51,5 +51,6 @@ include("factorizations/qr.jl") include("factorizations/lq.jl") include("factorizations/polar.jl") include("factorizations/orthnull.jl") +include("factorizations/eig.jl") end diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index c22907b8..b72f1e41 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -348,6 +348,15 @@ function Base.similar( return @interface BlockSparseArrayInterface() similar(a, elt, axes) end +struct BlockType{T} end +BlockType(x) = BlockType{x}() +function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}, ax) where {T} + return BlockSparseArray{eltype(T),ndims(T),T}(undef, ax) +end +function Base.similar(a::AbstractBlockSparseArray, T::BlockType) + return similar(a, T, axes(a)) +end + # TODO: Implement this in a more generic way using a smarter `copyto!`, # which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls. # These are defined for now to avoid scalar indexing issues when there diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index d11f604c..9025c8f2 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -45,6 +45,9 @@ end function eachstoredblockdiagindex(a::AbstractArray) return eachblockstoredindex(a) ∩ blockdiagindices(a) end +function eachunstoredblockdiagindex(a::AbstractArray) + return setdiff(blockdiagindices(a), eachblockstoredindex(a)) +end # Like `BlockArrays.eachblock` but only iterating # over stored blocks. diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl new file mode 100644 index 00000000..82358f48 --- /dev/null +++ b/src/factorizations/eig.jl @@ -0,0 +1,92 @@ +using BlockArrays: blocksizes +using DiagonalArrays: diagonal +using LinearAlgebra: LinearAlgebra, Diagonal +using MatrixAlgebraKit: + MatrixAlgebraKit, + TruncationStrategy, + check_input, + default_eig_algorithm, + default_eigh_algorithm, + diagview, + eig_full!, + eig_trunc!, + eig_vals!, + eigh_full!, + eigh_trunc!, + eigh_vals!, + findtruncated + +for f in [:default_eig_algorithm, :default_eigh_algorithm] + @eval begin + function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...) + alg = $f(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) + end + end +end + +function MatrixAlgebraKit.check_input( + ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V) +) + @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) + @assert eltype(V) === eltype(D) === complex(eltype(A)) + @assert axes(A, 1) == axes(A, 2) + @assert axes(A) == axes(D) == axes(V) + return nothing +end +function MatrixAlgebraKit.check_input( + ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V) +) + @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) + @assert eltype(V) === eltype(A) + @assert eltype(D) === real(eltype(A)) + @assert axes(A, 1) == axes(A, 2) + @assert axes(A) == axes(D) == axes(V) + return nothing +end + +for f in [:eig_full!, :eigh_full!] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm + ) + Td, Tv = fieldtypes(Base.promote_op($f, blocktype(A), typeof(alg.alg))) + D = similar(A, BlockType(Td)) + V = similar(A, BlockType(Tv)) + return (D, V) + end + function MatrixAlgebraKit.$f( + A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm + ) + check_input($f, A, (D, V)) + for I in eachstoredblockdiagindex(A) + D[I], V[I] = $f(@view(A[I]), alg.alg) + end + for I in eachunstoredblockdiagindex(A) + # TODO: Support setting `LinearAlgebra.I` directly, and/or + # using `FillArrays.Eye`. + V[I] = LinearAlgebra.I(size(@view(V[I]), 1)) + end + return (D, V) + end + end +end + +for f in [:eig_vals!, :eigh_vals!] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm + ) + T = Base.promote_op($f, blocktype(A), typeof(alg.alg)) + return similar(A, BlockType(T), axes(A, 1)) + end + function MatrixAlgebraKit.$f( + A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm + ) + for I in eachblockstoredindex(A) + D[I] = $f(@view!(A[I]), alg.alg) + end + return D + end + end +end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 673e5763..1f8f4a42 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -1,4 +1,5 @@ -using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full! +using MatrixAlgebraKit: + MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full! """ BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm) @@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output( end function MatrixAlgebraKit.check_input( - ::typeof(svd_compact!), A::AbstractBlockSparseMatrix, USVᴴ + ::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ) ) - U, S, Vt = USVᴴ @assert isa(U, AbstractBlockSparseMatrix) && isa(S, AbstractBlockSparseMatrix) && - isa(Vt, AbstractBlockSparseMatrix) - @assert eltype(A) == eltype(U) == eltype(Vt) + isa(Vᴴ, AbstractBlockSparseMatrix) + @assert eltype(A) == eltype(U) == eltype(Vᴴ) @assert real(eltype(A)) == eltype(S) - @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 2) + @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2) @assert axes(S, 1) == axes(S, 2) - return nothing end function MatrixAlgebraKit.check_input( - ::typeof(svd_full!), A::AbstractBlockSparseMatrix, USVᴴ + ::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ) ) - U, S, Vt = USVᴴ @assert isa(U, AbstractBlockSparseMatrix) && isa(S, AbstractBlockSparseMatrix) && - isa(Vt, AbstractBlockSparseMatrix) - @assert eltype(A) == eltype(U) == eltype(Vt) + isa(Vᴴ, AbstractBlockSparseMatrix) + @assert eltype(A) == eltype(U) == eltype(Vᴴ) @assert real(eltype(A)) == eltype(S) - @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 1) == axes(Vt, 2) + @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2) @assert axes(S, 2) == axes(A, 2) - return nothing end function MatrixAlgebraKit.svd_compact!( - A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm + A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm ) - MatrixAlgebraKit.check_input(svd_compact!, A, USVᴴ) - U, S, Vt = USVᴴ + check_input(svd_compact!, A, (U, S, Vᴴ)) # do decomposition on each block for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) - usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol])) + usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg) @assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place" end @@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!( emptycols = setdiff(1:blocksize(A, 2), bcolIs) # needs copyto! instead because size(::LinearAlgebra.I) doesn't work # U[Block(row, col)] = LinearAlgebra.I - # Vt[Block(col, col)] = LinearAlgebra.I + # Vᴴ[Block(col, col)] = LinearAlgebra.I for (row, col) in zip(emptyrows, emptycols) copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I) - copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I) end - return USVᴴ + return (U, S, Vᴴ) end function MatrixAlgebraKit.svd_full!( - A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm + A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm ) - MatrixAlgebraKit.check_input(svd_full!, A, USVᴴ) - U, S, Vt = USVᴴ + check_input(svd_full!, A, (U, S, Vᴴ)) # do decomposition on each block for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) - usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol])) + usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg) @assert usvᴴ === usvᴴ′ "svd_full! might not be in-place" end @@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!( # Vt[Block(col, col)] = LinearAlgebra.I for (row, col) in zip(emptyrows, emptycols) copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I) - copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I) end # also handle extra rows/cols for i in (length(emptyrows) + 1):length(emptycols) - copyto!(@view!(Vt[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I) end bn = blocksize(A, 2) for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows)) copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I) end - return USVᴴ + return (U, S, Vᴴ) end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 175e78f3..1ad4e72e 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -1,4 +1,4 @@ -using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc! +using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc! function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T} D = BlockSparseVector{T}(undef, axes(A, 1)) @@ -21,18 +21,29 @@ struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: Truncat strategy::T end -const TBlockUSVᴴ = Tuple{ - <:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix -} - function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, + strategy::TruncationStrategy, ) # TODO assert blockdiagonal return MatrixAlgebraKit.truncate!( svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy) ) end +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::NTuple{2,AbstractBlockSparseMatrix}, + strategy::TruncationStrategy, + ) + return MatrixAlgebraKit.truncate!( + $f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy) + ) + end + end +end # cannot use regular slicing here: I want to slice without altering blockstructure # solution: use boolean indexing and slice the mask, effectively cheaply inverting the map @@ -47,9 +58,21 @@ end function MatrixAlgebraKit.truncate!( ::typeof(svd_trunc!), - (U, S, Vᴴ)::TBlockUSVᴴ, + (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) I = MatrixAlgebraKit.findtruncated(diagview(S), strategy) return (U[:, I], S[I, I], Vᴴ[I, :]) end +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::NTuple{2,AbstractBlockSparseMatrix}, + strategy::BlockPermutedDiagonalTruncationStrategy, + ) + I = MatrixAlgebraKit.findtruncated(diagview(D), strategy) + return (D[I, I], V[:, I]) + end + end +end diff --git a/test/Project.toml b/test/Project.toml index 66cf5f46..0463a547 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 2ae4ebce..222ae1c1 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,6 +1,14 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar -using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex +using BlockSparseArrays: + BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex using MatrixAlgebraKit: + diagview, + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, left_orth, left_polar, lq_compact, @@ -14,8 +22,9 @@ using MatrixAlgebraKit: svd_trunc, truncrank, trunctol -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart using Random: Random +using StableRNGs: StableRNG using Test: @inferred, @testset, @test function test_svd(a, (U, S, Vᴴ); full=false) @@ -100,7 +109,7 @@ end # svd_trunc! # ---------- -@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params +@testset "svd_trunc ($m, $n) BlockSparseMatrix{$T}" for ((m, n), T) in test_params a = BlockSparseArray{T}(undef, m, n) # test blockdiagonal @@ -273,3 +282,111 @@ end @test size(U, 1) ≤ 2 @test Matrix(U * U') ≈ LinearAlgebra.I end + +@testset "eig_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 2) + A[Block(2, 2)] = randn(rng, T, 3, 3) + + D, V = eig_full(A) + @test size(D) == size(A) + @test size(D) == size(A) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D +end + +@testset "eig_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 2) + A[Block(2, 2)] = randn(rng, T, 3, 3) + + D = eig_vals(A) + @test size(D) == (size(A, 1),) + @test blockstoredlength(D) == 2 + D′ = eig_vals(Matrix(A)) + @test sort(D; by=abs) ≈ sort(D′; by=abs) +end + +@testset "eig_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + D1 = [1.0, 0.1] + V1 = randn(rng, T, 2, 2) + A1 = V1 * Diagonal(D1) * inv(V1) + D2 = [1.0, 0.5, 0.1] + V2 = randn(rng, T, 3, 3) + A2 = V2 * Diagonal(D2) * inv(V2) + A[Block(1, 1)] = A1 + A[Block(2, 2)] = A2 + + D, V = eig_trunc(A; trunc=(; maxrank=3)) + @test size(D) == (3, 3) + @test size(D) == (3, 3) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D + @test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) ≈ D1[1:1] + @test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) ≈ D2[1:2] +end + +herm(x) = parent(hermitianpart(x)) + +@testset "eigh_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = herm(randn(rng, T, 2, 2)) + A[Block(2, 2)] = herm(randn(rng, T, 3, 3)) + + D, V = eigh_full(A) + @test size(D) == size(A) + @test size(D) == size(A) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D +end + +@testset "eigh_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = herm(randn(rng, T, 2, 2)) + A[Block(2, 2)] = herm(randn(rng, T, 3, 3)) + + D = eigh_vals(A) + @test size(D) == (size(A, 1),) + @test blockstoredlength(D) == 2 + D′ = eigh_vals(Matrix(A)) + @test sort(D; by=abs) ≈ sort(D′; by=abs) +end + +@testset "eigh_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + D1 = [1.0, 0.1] + V1, _ = qr_compact(randn(rng, T, 2, 2)) + A1 = V1 * Diagonal(D1) * V1' + D2 = [1.0, 0.5, 0.1] + V2, _ = qr_compact(randn(rng, T, 3, 3)) + A2 = V2 * Diagonal(D2) * V2' + A[Block(1, 1)] = herm(A1) + A[Block(2, 2)] = herm(A2) + + D, V = eigh_trunc(A; trunc=(; maxrank=3)) + @test size(D) == (3, 3) + @test size(D) == (3, 3) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D + @test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) ≈ D1[1:1] + @test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) ≈ D2[1:2] +end