From 9acb1570926befd1fa940a038d0f1e16608c8ee2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 10:12:43 -0400 Subject: [PATCH 1/8] [WIP] Block sparse eig --- Project.toml | 4 +-- src/BlockSparseArrays.jl | 1 + src/factorizations/eig.jl | 56 +++++++++++++++++++++++++++++++++++++ src/factorizations/lq.jl | 20 ++++--------- src/factorizations/polar.jl | 11 -------- src/factorizations/qr.jl | 18 +++--------- src/factorizations/svd.jl | 22 +++------------ 7 files changed, 72 insertions(+), 60 deletions(-) create mode 100644 src/factorizations/eig.jl diff --git a/Project.toml b/Project.toml index c15f9724..02ce898e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.6.8" +version = "0.6.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -39,7 +39,7 @@ GPUArraysCore = "0.1.0, 0.2" LinearAlgebra = "1.10" MacroTools = "0.5.13" MapBroadcast = "0.1.5" -MatrixAlgebraKit = "0.2" +MatrixAlgebraKit = "0.2.2" SparseArraysBase = "0.5" SplitApplyCombine = "1.2.3" TensorAlgebra = "0.3.2" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 6470fc61..ec996e06 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -51,5 +51,6 @@ include("factorizations/qr.jl") include("factorizations/lq.jl") include("factorizations/polar.jl") include("factorizations/orthnull.jl") +include("factorizations/eig.jl") end diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl new file mode 100644 index 00000000..c384a913 --- /dev/null +++ b/src/factorizations/eig.jl @@ -0,0 +1,56 @@ +using MatrixAlgebraKit: + MatrixAlgebraKit, default_eig_algorithm, default_eigh_algorithm, eig_full!, eigh_full! + +function MatrixAlgebraKit.default_eig_algorithm( + arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + alg = default_eig_algorithm(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) +end + +function MatrixAlgebraKit.initialize_output( + ::typeof(eig_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm +) + D = similar(A, complex(eltype(A))) + V = similar(A, complex(eltype(A))) + return (D, V) +end + +function MatrixAlgebraKit.eig_full!( + A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm +) + for I in blockdiagindices(A) + d, v = eig_full!(A[I], alg.alg) + D[I] = d + V[I] = v + end + return (D, V) +end + +# TODO: this is a hardcoded for now to get around this function not being defined in the +# type domain +function MatrixAlgebraKit.default_eigh_algorithm( + arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + alg = default_eigh_algorithm(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) +end + +function MatrixAlgebraKit.initialize_output( + ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm +) + D = similar(A, complex(eltype(A))) + V = similar(A, complex(eltype(A))) + return (D, V) +end + +function MatrixAlgebraKit.eigh_full!( + A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm +) + for I in blockdiagindices(A) + d, v = eigh_full!(A[I], alg.alg) + D[I] = d + V[I] = v + end + return (D, V) +end diff --git a/src/factorizations/lq.jl b/src/factorizations/lq.jl index 4a07cfa6..d3bebc78 100644 --- a/src/factorizations/lq.jl +++ b/src/factorizations/lq.jl @@ -1,22 +1,12 @@ -using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full! +using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full! # TODO: this is a hardcoded for now to get around this function not being defined in the # type domain -function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs... +function MatrixAlgebraKit.default_lq_algorithm( + arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - return default_blocksparse_lq_algorithm(A; kwargs...) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs... -) - return default_blocksparse_lq_algorithm(A; kwargs...) + alg = default_lq_algorithm(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) end function similar_output( diff --git a/src/factorizations/polar.jl b/src/factorizations/polar.jl index 9b9c2831..f123662f 100644 --- a/src/factorizations/polar.jl +++ b/src/factorizations/polar.jl @@ -46,14 +46,3 @@ function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarV P = U * S * copy(U') return (P, Wᴴ) end - -function MatrixAlgebraKit.default_algorithm( - ::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs... -) - return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs... -) - return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) -end diff --git a/src/factorizations/qr.jl b/src/factorizations/qr.jl index 55a0b93e..7a4199ce 100644 --- a/src/factorizations/qr.jl +++ b/src/factorizations/qr.jl @@ -3,21 +3,11 @@ using MatrixAlgebraKit: # TODO: this is a hardcoded for now to get around this function not being defined in the # type domain -function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_compact!), A::AbstractBlockSparseMatrix; kwargs... +function MatrixAlgebraKit.default_qr_algorithm( + arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - return default_qr_algorithm(A; kwargs...) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_full!), A::AbstractBlockSparseMatrix; kwargs... -) - return default_qr_algorithm(A; kwargs...) + alg = default_qr_algorithm(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) end function similar_output( diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 187795ea..1b15527d 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -12,25 +12,11 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <: alg::A end -function MatrixAlgebraKit.default_svd_algorithm(A; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - # TODO: this is a hardcoded for now to get around this function not being defined in the - # type domain - # alg = MatrixAlgebraKit.default_algorithm(f, blocktype(A); kwargs...) - alg = MatrixAlgebraKit.LAPACK_DivideAndConquer(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end - -function MatrixAlgebraKit.default_algorithm( - f::typeof(svd_compact!), A::AbstractBlockSparseMatrix; kwargs... +function MatrixAlgebraKit.default_svd_algorithm( + arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - return default_svd_algorithm(A; kwargs...) -end -function MatrixAlgebraKit.default_algorithm( - f::typeof(svd_full!), A::AbstractBlockSparseMatrix; kwargs... -) - return default_svd_algorithm(A; kwargs...) + alg = default_svd_algorithm(blocktype(arrayt)) + return BlockPermutedDiagonalAlgorithm(alg) end function similar_output( From 707797914b7edaa191c4bd4cd60d7bd0d8e15640 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 10:27:15 -0400 Subject: [PATCH 2/8] Remove outdated comment --- src/factorizations/eig.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index c384a913..8d23090f 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -27,8 +27,6 @@ function MatrixAlgebraKit.eig_full!( return (D, V) end -# TODO: this is a hardcoded for now to get around this function not being defined in the -# type domain function MatrixAlgebraKit.default_eigh_algorithm( arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... ) From 8af8e5b8a6a1d82790dbd079ae1c4c4b70b744c5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 17:33:37 -0400 Subject: [PATCH 3/8] More specific block type, code sharing across Hermitian and non-Hermitian --- .../wrappedabstractblocksparsearray.jl | 6 ++ src/factorizations/eig.jl | 65 ++++++++----------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 84c24ed8..83b476d6 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -335,6 +335,12 @@ function Base.similar( return @interface BlockSparseArrayInterface() similar(a, elt, axes) end +struct BlockType{T} end +BlockType(x) = BlockType{x}() +function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}) where {T} + return BlockSparseArray{eltype(T),ndims(T),T}(undef, axes(a)) +end + # TODO: Implement this in a more generic way using a smarter `copyto!`, # which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls. # These are defined for now to avoid scalar indexing issues when there diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index 8d23090f..cd504ad2 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -1,54 +1,45 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_eig_algorithm, default_eigh_algorithm, eig_full!, eigh_full! -function MatrixAlgebraKit.default_eig_algorithm( - arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... +function initialize_blocksparse_eig_output( + f, A::AbstractMatrix, alg::BlockPermutedDiagonalAlgorithm ) - alg = default_eig_algorithm(blocktype(arrayt); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end - -function MatrixAlgebraKit.initialize_output( - ::typeof(eig_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm -) - D = similar(A, complex(eltype(A))) - V = similar(A, complex(eltype(A))) + Td, Tv = fieldtypes(Base.promote_op(f, blocktype(A), typeof(alg.alg))) + D = similar(A, BlockType(Td)) + V = similar(A, BlockType(Tv)) return (D, V) end -function MatrixAlgebraKit.eig_full!( - A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm +function blocksparse_eig_full!( + f, A::AbstractMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm ) for I in blockdiagindices(A) - d, v = eig_full!(A[I], alg.alg) - D[I] = d - V[I] = v + d, v = f(@view!(A[I]), alg.alg) + D[I], V[I] = d, v end return (D, V) end -function MatrixAlgebraKit.default_eigh_algorithm( - arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs... -) - alg = default_eigh_algorithm(blocktype(arrayt); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end - -function MatrixAlgebraKit.initialize_output( - ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm -) - D = similar(A, complex(eltype(A))) - V = similar(A, complex(eltype(A))) - return (D, V) +for f in [:default_eig_algorithm, :default_eigh_algorithm] + @eval begin + function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...) + alg = $f(blocktype(arrayt); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) + end + end end -function MatrixAlgebraKit.eigh_full!( - A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm -) - for I in blockdiagindices(A) - d, v = eigh_full!(A[I], alg.alg) - D[I] = d - V[I] = v +for f in [:eig_full!, :eigh_full!] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm + ) + return initialize_blocksparse_eig_output($f, A, alg) + end + function MatrixAlgebraKit.$f( + A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm + ) + return blocksparse_eig_full!($f, A, (D, V), alg) + end end - return (D, V) end From e3bc69a22d775297f9da171f53730de99f41a4c3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 17:56:34 -0400 Subject: [PATCH 4/8] Define block sparse eig_vals --- .../blocksparsearrayinterface.jl | 3 + src/factorizations/eig.jl | 62 ++++++++++++------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 0015de0b..c0b557dd 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -45,6 +45,9 @@ end function eachstoredblockdiagindex(a::AbstractArray) return eachblockstoredindex(a) ∩ blockdiagindices(a) end +function eachunstoredblockdiagindex(a::AbstractArray) + return setdiff(blockdiagindices(a), eachblockstoredindex(a)) +end # Like `BlockArrays.eachblock` but only iterating # over stored blocks. diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index cd504ad2..a3ab69d5 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -1,24 +1,13 @@ +using BlockArrays: blocksizes +using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit: - MatrixAlgebraKit, default_eig_algorithm, default_eigh_algorithm, eig_full!, eigh_full! - -function initialize_blocksparse_eig_output( - f, A::AbstractMatrix, alg::BlockPermutedDiagonalAlgorithm -) - Td, Tv = fieldtypes(Base.promote_op(f, blocktype(A), typeof(alg.alg))) - D = similar(A, BlockType(Td)) - V = similar(A, BlockType(Tv)) - return (D, V) -end - -function blocksparse_eig_full!( - f, A::AbstractMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm -) - for I in blockdiagindices(A) - d, v = f(@view!(A[I]), alg.alg) - D[I], V[I] = d, v - end - return (D, V) -end + MatrixAlgebraKit, + default_eig_algorithm, + default_eigh_algorithm, + eig_full!, + eig_vals!, + eigh_full!, + eigh_vals! for f in [:default_eig_algorithm, :default_eigh_algorithm] @eval begin @@ -34,12 +23,41 @@ for f in [:eig_full!, :eigh_full!] function MatrixAlgebraKit.initialize_output( ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm ) - return initialize_blocksparse_eig_output($f, A, alg) + Td, Tv = fieldtypes(Base.promote_op($f, blocktype(A), typeof(alg.alg))) + D = similar(A, BlockType(Td)) + V = similar(A, BlockType(Tv)) + return (D, V) end function MatrixAlgebraKit.$f( A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm ) - return blocksparse_eig_full!($f, A, (D, V), alg) + for I in eachstoredblockdiagindex(A) + D[I], V[I] = $f(@view(A[I]), alg.alg) + end + for I in eachunstoredblockdiagindex(A) + # TODO: Support setting `LinearAlgebra.I` directly, and/or + # using `FillArrays.Eye`. + V[I] = LinearAlgebra.I(first(blocksizes(A)[Int.(Tuple(I))...])) + end + return (D, V) + end + end +end + +for f in [:eig_vals!, :eigh_vals!] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm + ) + return similar(A, axes(A, 1)) + end + function MatrixAlgebraKit.$f( + A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm + ) + for I in eachblockstoredindex(A) + D[I] = $f(@view!(A[I]), alg.alg) + end + return D end end end From 36d9fba28183abb6e16d28b6ce1630c55ace12c0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 1 Jun 2025 10:30:18 -0400 Subject: [PATCH 5/8] Start implementing truncation --- src/factorizations/eig.jl | 33 +++++++++++++++++++++++++++++--- src/factorizations/truncation.jl | 2 +- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index a3ab69d5..1be93a90 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -1,13 +1,19 @@ using BlockArrays: blocksizes -using LinearAlgebra: LinearAlgebra +using DiagonalArrays: diagonal +using LinearAlgebra: LinearAlgebra, Diagonal using MatrixAlgebraKit: MatrixAlgebraKit, + TruncationStrategy, default_eig_algorithm, default_eigh_algorithm, + diagview, eig_full!, + eig_trunc!, eig_vals!, eigh_full!, - eigh_vals! + eigh_trunc!, + eigh_vals!, + findtruncated for f in [:default_eig_algorithm, :default_eigh_algorithm] @eval begin @@ -37,7 +43,7 @@ for f in [:eig_full!, :eigh_full!] for I in eachunstoredblockdiagindex(A) # TODO: Support setting `LinearAlgebra.I` directly, and/or # using `FillArrays.Eye`. - V[I] = LinearAlgebra.I(first(blocksizes(A)[Int.(Tuple(I))...])) + V[I] = LinearAlgebra.I(size(@view(V[I]), 1)) end return (D, V) end @@ -61,3 +67,24 @@ for f in [:eig_vals!, :eigh_vals!] end end end + +const TBlockDV = Tuple{AbstractBlockSparseMatrix,AbstractBlockSparseMatrix} + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::TBlockDV, strategy::TruncationStrategy + ) + return MatrixAlgebraKit.truncate!( + $f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy) + ) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::TBlockDV, strategy::BlockPermutedDiagonalTruncationStrategy + ) + d = diagview(D) + ind = findtruncated(d, strategy) + return diagonal(d[ind]), V[:, ind] + end + end +end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index e3362128..2f093e05 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -22,7 +22,7 @@ struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: Truncat end const TBlockUSVᴴ = Tuple{ - <:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix + AbstractBlockSparseMatrix,AbstractBlockSparseMatrix,AbstractBlockSparseMatrix } function MatrixAlgebraKit.truncate!( From 81062b5eca8002c35913db2652a62fbebc023511 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 5 Jun 2025 09:14:11 -0400 Subject: [PATCH 6/8] Typo fix in tests --- test/test_factorizations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 2ae4ebce..df4fe0c0 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -100,7 +100,7 @@ end # svd_trunc! # ---------- -@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params +@testset "svd_trunc ($m, $n) BlockSparseMatrix{$T}" for ((m, n), T) in test_params a = BlockSparseArray{T}(undef, m, n) # test blockdiagonal From 5146831c6dd76dbb2159585fc4a32a82c11a3e45 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 5 Jun 2025 10:58:12 -0400 Subject: [PATCH 7/8] More tests, small fixes --- .../wrappedabstractblocksparsearray.jl | 7 +- src/factorizations/eig.jl | 24 +--- src/factorizations/truncation.jl | 37 +++++- test/Project.toml | 1 + test/test_factorizations.jl | 121 +++++++++++++++++- 5 files changed, 157 insertions(+), 33 deletions(-) diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index d1a392ce..b72f1e41 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -350,8 +350,11 @@ end struct BlockType{T} end BlockType(x) = BlockType{x}() -function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}) where {T} - return BlockSparseArray{eltype(T),ndims(T),T}(undef, axes(a)) +function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}, ax) where {T} + return BlockSparseArray{eltype(T),ndims(T),T}(undef, ax) +end +function Base.similar(a::AbstractBlockSparseArray, T::BlockType) + return similar(a, T, axes(a)) end # TODO: Implement this in a more generic way using a smarter `copyto!`, diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index 1be93a90..6c50bea0 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -55,7 +55,8 @@ for f in [:eig_vals!, :eigh_vals!] function MatrixAlgebraKit.initialize_output( ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm ) - return similar(A, axes(A, 1)) + T = Base.promote_op($f, blocktype(A), typeof(alg.alg)) + return similar(A, BlockType(T), axes(A, 1)) end function MatrixAlgebraKit.$f( A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm @@ -67,24 +68,3 @@ for f in [:eig_vals!, :eigh_vals!] end end end - -const TBlockDV = Tuple{AbstractBlockSparseMatrix,AbstractBlockSparseMatrix} - -for f in [:eig_trunc!, :eigh_trunc!] - @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::TBlockDV, strategy::TruncationStrategy - ) - return MatrixAlgebraKit.truncate!( - $f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy) - ) - end - function MatrixAlgebraKit.truncate!( - ::typeof($f), (D, V)::TBlockDV, strategy::BlockPermutedDiagonalTruncationStrategy - ) - d = diagview(D) - ind = findtruncated(d, strategy) - return diagonal(d[ind]), V[:, ind] - end - end -end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index b0a51404..1ad4e72e 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -1,4 +1,4 @@ -using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc! +using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc! function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T} D = BlockSparseVector{T}(undef, axes(A, 1)) @@ -21,18 +21,29 @@ struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: Truncat strategy::T end -const TBlockUSVᴴ = Tuple{ - AbstractBlockSparseMatrix,AbstractBlockSparseMatrix,AbstractBlockSparseMatrix -} - function MatrixAlgebraKit.truncate!( - ::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, + strategy::TruncationStrategy, ) # TODO assert blockdiagonal return MatrixAlgebraKit.truncate!( svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy) ) end +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::NTuple{2,AbstractBlockSparseMatrix}, + strategy::TruncationStrategy, + ) + return MatrixAlgebraKit.truncate!( + $f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy) + ) + end + end +end # cannot use regular slicing here: I want to slice without altering blockstructure # solution: use boolean indexing and slice the mask, effectively cheaply inverting the map @@ -47,9 +58,21 @@ end function MatrixAlgebraKit.truncate!( ::typeof(svd_trunc!), - (U, S, Vᴴ)::TBlockUSVᴴ, + (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) I = MatrixAlgebraKit.findtruncated(diagview(S), strategy) return (U[:, I], S[I, I], Vᴴ[I, :]) end +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::NTuple{2,AbstractBlockSparseMatrix}, + strategy::BlockPermutedDiagonalTruncationStrategy, + ) + I = MatrixAlgebraKit.findtruncated(diagview(D), strategy) + return (D[I, I], V[:, I]) + end + end +end diff --git a/test/Project.toml b/test/Project.toml index 66cf5f46..0463a547 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index df4fe0c0..222ae1c1 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,6 +1,14 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar -using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex +using BlockSparseArrays: + BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex using MatrixAlgebraKit: + diagview, + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, left_orth, left_polar, lq_compact, @@ -14,8 +22,9 @@ using MatrixAlgebraKit: svd_trunc, truncrank, trunctol -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart using Random: Random +using StableRNGs: StableRNG using Test: @inferred, @testset, @test function test_svd(a, (U, S, Vᴴ); full=false) @@ -273,3 +282,111 @@ end @test size(U, 1) ≤ 2 @test Matrix(U * U') ≈ LinearAlgebra.I end + +@testset "eig_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 2) + A[Block(2, 2)] = randn(rng, T, 3, 3) + + D, V = eig_full(A) + @test size(D) == size(A) + @test size(D) == size(A) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D +end + +@testset "eig_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 2) + A[Block(2, 2)] = randn(rng, T, 3, 3) + + D = eig_vals(A) + @test size(D) == (size(A, 1),) + @test blockstoredlength(D) == 2 + D′ = eig_vals(Matrix(A)) + @test sort(D; by=abs) ≈ sort(D′; by=abs) +end + +@testset "eig_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + D1 = [1.0, 0.1] + V1 = randn(rng, T, 2, 2) + A1 = V1 * Diagonal(D1) * inv(V1) + D2 = [1.0, 0.5, 0.1] + V2 = randn(rng, T, 3, 3) + A2 = V2 * Diagonal(D2) * inv(V2) + A[Block(1, 1)] = A1 + A[Block(2, 2)] = A2 + + D, V = eig_trunc(A; trunc=(; maxrank=3)) + @test size(D) == (3, 3) + @test size(D) == (3, 3) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D + @test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) ≈ D1[1:1] + @test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) ≈ D2[1:2] +end + +herm(x) = parent(hermitianpart(x)) + +@testset "eigh_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = herm(randn(rng, T, 2, 2)) + A[Block(2, 2)] = herm(randn(rng, T, 3, 3)) + + D, V = eigh_full(A) + @test size(D) == size(A) + @test size(D) == size(A) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D +end + +@testset "eigh_vals (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + A[Block(1, 1)] = herm(randn(rng, T, 2, 2)) + A[Block(2, 2)] = herm(randn(rng, T, 3, 3)) + + D = eigh_vals(A) + @test size(D) == (size(A, 1),) + @test blockstoredlength(D) == 2 + D′ = eigh_vals(Matrix(A)) + @test sort(D; by=abs) ≈ sort(D′; by=abs) +end + +@testset "eigh_trunc (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [2, 3])) + rng = StableRNG(123) + D1 = [1.0, 0.1] + V1, _ = qr_compact(randn(rng, T, 2, 2)) + A1 = V1 * Diagonal(D1) * V1' + D2 = [1.0, 0.5, 0.1] + V2, _ = qr_compact(randn(rng, T, 3, 3)) + A2 = V2 * Diagonal(D2) * V2' + A[Block(1, 1)] = herm(A1) + A[Block(2, 2)] = herm(A2) + + D, V = eigh_trunc(A; trunc=(; maxrank=3)) + @test size(D) == (3, 3) + @test size(D) == (3, 3) + @test blockstoredlength(D) == 2 + @test blockstoredlength(V) == 2 + @test issetequal(eachblockstoredindex(D), [Block(1, 1), Block(2, 2)]) + @test issetequal(eachblockstoredindex(V), [Block(1, 1), Block(2, 2)]) + @test A * V ≈ V * D + @test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) ≈ D1[1:1] + @test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) ≈ D2[1:2] +end From 4df87f444b180f7db1a253b3f69d56a2e48f789a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 5 Jun 2025 11:14:18 -0400 Subject: [PATCH 8/8] Check inputs --- src/factorizations/eig.jl | 22 ++++++++++++++++++ src/factorizations/svd.jl | 49 ++++++++++++++++++--------------------- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index 6c50bea0..82358f48 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -4,6 +4,7 @@ using LinearAlgebra: LinearAlgebra, Diagonal using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy, + check_input, default_eig_algorithm, default_eigh_algorithm, diagview, @@ -24,6 +25,26 @@ for f in [:default_eig_algorithm, :default_eigh_algorithm] end end +function MatrixAlgebraKit.check_input( + ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V) +) + @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) + @assert eltype(V) === eltype(D) === complex(eltype(A)) + @assert axes(A, 1) == axes(A, 2) + @assert axes(A) == axes(D) == axes(V) + return nothing +end +function MatrixAlgebraKit.check_input( + ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V) +) + @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) + @assert eltype(V) === eltype(A) + @assert eltype(D) === real(eltype(A)) + @assert axes(A, 1) == axes(A, 2) + @assert axes(A) == axes(D) == axes(V) + return nothing +end + for f in [:eig_full!, :eigh_full!] @eval begin function MatrixAlgebraKit.initialize_output( @@ -37,6 +58,7 @@ for f in [:eig_full!, :eigh_full!] function MatrixAlgebraKit.$f( A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm ) + check_input($f, A, (D, V)) for I in eachstoredblockdiagindex(A) D[I], V[I] = $f(@view(A[I]), alg.alg) end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 673e5763..1f8f4a42 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -1,4 +1,5 @@ -using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full! +using MatrixAlgebraKit: + MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full! """ BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm) @@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output( end function MatrixAlgebraKit.check_input( - ::typeof(svd_compact!), A::AbstractBlockSparseMatrix, USVᴴ + ::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ) ) - U, S, Vt = USVᴴ @assert isa(U, AbstractBlockSparseMatrix) && isa(S, AbstractBlockSparseMatrix) && - isa(Vt, AbstractBlockSparseMatrix) - @assert eltype(A) == eltype(U) == eltype(Vt) + isa(Vᴴ, AbstractBlockSparseMatrix) + @assert eltype(A) == eltype(U) == eltype(Vᴴ) @assert real(eltype(A)) == eltype(S) - @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 2) + @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2) @assert axes(S, 1) == axes(S, 2) - return nothing end function MatrixAlgebraKit.check_input( - ::typeof(svd_full!), A::AbstractBlockSparseMatrix, USVᴴ + ::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ) ) - U, S, Vt = USVᴴ @assert isa(U, AbstractBlockSparseMatrix) && isa(S, AbstractBlockSparseMatrix) && - isa(Vt, AbstractBlockSparseMatrix) - @assert eltype(A) == eltype(U) == eltype(Vt) + isa(Vᴴ, AbstractBlockSparseMatrix) + @assert eltype(A) == eltype(U) == eltype(Vᴴ) @assert real(eltype(A)) == eltype(S) - @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 1) == axes(Vt, 2) + @assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2) @assert axes(S, 2) == axes(A, 2) - return nothing end function MatrixAlgebraKit.svd_compact!( - A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm + A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm ) - MatrixAlgebraKit.check_input(svd_compact!, A, USVᴴ) - U, S, Vt = USVᴴ + check_input(svd_compact!, A, (U, S, Vᴴ)) # do decomposition on each block for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) - usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol])) + usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg) @assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place" end @@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!( emptycols = setdiff(1:blocksize(A, 2), bcolIs) # needs copyto! instead because size(::LinearAlgebra.I) doesn't work # U[Block(row, col)] = LinearAlgebra.I - # Vt[Block(col, col)] = LinearAlgebra.I + # Vᴴ[Block(col, col)] = LinearAlgebra.I for (row, col) in zip(emptyrows, emptycols) copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I) - copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I) end - return USVᴴ + return (U, S, Vᴴ) end function MatrixAlgebraKit.svd_full!( - A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm + A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm ) - MatrixAlgebraKit.check_input(svd_full!, A, USVᴴ) - U, S, Vt = USVᴴ + check_input(svd_full!, A, (U, S, Vᴴ)) # do decomposition on each block for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) - usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol])) + usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg) @assert usvᴴ === usvᴴ′ "svd_full! might not be in-place" end @@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!( # Vt[Block(col, col)] = LinearAlgebra.I for (row, col) in zip(emptyrows, emptycols) copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I) - copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I) end # also handle extra rows/cols for i in (length(emptyrows) + 1):length(emptycols) - copyto!(@view!(Vt[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I) + copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I) end bn = blocksize(A, 2) for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows)) copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I) end - return USVᴴ + return (U, S, Vᴴ) end