From de3e859c257bbbdf27d1f37faf02e488d9335a49 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 16 Jun 2025 13:25:11 -0400 Subject: [PATCH 1/2] Make block sparse SVD more generic --- Project.toml | 2 +- src/BlockSparseArrays.jl | 1 + .../getunstoredblock.jl | 48 ++++++++++++------- src/factorizations/svd.jl | 19 ++++---- src/factorizations/tensorproducts.jl | 19 ++++++++ 5 files changed, 60 insertions(+), 29 deletions(-) create mode 100644 src/factorizations/tensorproducts.jl diff --git a/Project.toml b/Project.toml index 72d74f54..720a92be 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.12" +version = "0.7.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index ec996e06..42293432 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -45,6 +45,7 @@ include("blocksparsearray/blockdiagonalarray.jl") include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") # factorizations +include("factorizations/tensorproducts.jl") include("factorizations/svd.jl") include("factorizations/truncation.jl") include("factorizations/qr.jl") diff --git a/src/blocksparsearrayinterface/getunstoredblock.jl b/src/blocksparsearrayinterface/getunstoredblock.jl index 36caddea..cf424304 100644 --- a/src/blocksparsearrayinterface/getunstoredblock.jl +++ b/src/blocksparsearrayinterface/getunstoredblock.jl @@ -5,6 +5,35 @@ struct GetUnstoredBlock{Axes} axes::Axes end +# Allow customizing based on the block index. +function unstored_block( + A::Type{<:AbstractArray{<:Any,N}}, ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N} +) where {N} + return unstored_block(A, ax) +end +function unstored_block( + A::Type{<:AbstractArray{<:Any,N}}, ax::NTuple{N,AbstractUnitRange{<:Integer}} +) where {N} + a = similar(A, ax) + zero!(a) + return a +end + +using LinearAlgebra: Diagonal +# TODO: This is a hack and is also type-unstable. +function unstored_block( + A::Type{<:Diagonal{<:Any,V}}, ax::NTuple{2,AbstractUnitRange{<:Integer}}, I::Block{2} +) where {V} + if allequal(Tuple(I)) + # Diagonal blocks. + diag = zero!(similar(V, first(ax))) + return Diagonal(diag) + else + # Off-diagonal blocks. + return zero!(similar(similartype(V, typeof(ax)), ax)) + end +end + @inline function (f::GetUnstoredBlock)( a::AbstractArray{<:Any,N}, I::Vararg{Int,N} ) where {N} @@ -13,9 +42,7 @@ end b_ax = ntuple(ndims(a)) do d return only(axes(f.axes[d][Block(I[d])])) end - b = similar(eltype(a), b_ax) - zero!(b) - return b + return unstored_block(eltype(a), b_ax, Block(I)) end # TODO: Use `Base.to_indices`. @inline function (f::GetUnstoredBlock)( @@ -23,18 +50,3 @@ end ) where {N} return f(a, Tuple(I)...) end - -# TODO: this is a hack and is also type-unstable -function (f::GetUnstoredBlock)( - a::AbstractMatrix{LinearAlgebra.Diagonal{T,V}}, I::Vararg{Int,2} -) where {T,V} - b_size = ntuple(ndims(a)) do d - return length(f.axes[d][Block(I[d])]) - end - if I[1] == I[2] - diag = zero!(similar(V, b_size[1])) - return LinearAlgebra.Diagonal{T,V}(diag) - else - return zeros(T, b_size...) - end -end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index be7ca75c..b57b8eb1 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -25,10 +25,11 @@ end function similar_output( ::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm ) - U = similar(A, axes(A, 1), S_axes[1]) - S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes) - Vt = similar(A, S_axes[2], axes(A, 2)) - return U, S, Vt + BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg))) + U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1])) + S = similar(A, BlockType(BS), S_axes) + Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2))) + return U, S, Vᴴ end function MatrixAlgebraKit.initialize_output( @@ -48,9 +49,8 @@ function MatrixAlgebraKit.initialize_output( bcolIs = Int.(last.(Tuple.(bIs))) for bI in eachblockstoredindex(A) row, col = Int.(Tuple(bI)) - len = minimum(length, (brows[row], bcols[col])) - u_axes[col] = brows[row][Base.OneTo(len)] - v_axes[col] = bcols[col][Base.OneTo(len)] + u_axes[col] = infimum(brows[row], bcols[col]) + v_axes[col] = infimum(bcols[col], brows[row]) end # fill in values for blocks that aren't present, pairing them in order of occurence @@ -58,9 +58,8 @@ function MatrixAlgebraKit.initialize_output( emptyrows = setdiff(1:bm, browIs) emptycols = setdiff(1:bn, bcolIs) for (row, col) in zip(emptyrows, emptycols) - len = minimum(length, (brows[row], bcols[col])) - u_axes[col] = brows[row][Base.OneTo(len)] - v_axes[col] = bcols[col][Base.OneTo(len)] + u_axes[col] = infimum(brows[row], bcols[col]) + v_axes[col] = infimum(bcols[col], brows[row]) end u_axis = mortar_axis(u_axes) diff --git a/src/factorizations/tensorproducts.jl b/src/factorizations/tensorproducts.jl new file mode 100644 index 00000000..bbf02aeb --- /dev/null +++ b/src/factorizations/tensorproducts.jl @@ -0,0 +1,19 @@ +function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("infimum only defined for ranges starting at 1")) + if length(r1) ≤ length(r2) + return r1 + else + return r1[r2] + end +end + +function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("supremum only defined for ranges starting at 1")) + if length(r1) ≥ length(r2) + return r1 + else + return r2 + end +end From 15f1f230285035c675d88b17937e604f973b76f0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 16 Jun 2025 15:15:42 -0400 Subject: [PATCH 2/2] Revert some changes to getunstoredblock --- .../getunstoredblock.jl | 56 ++++++++----------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/blocksparsearrayinterface/getunstoredblock.jl b/src/blocksparsearrayinterface/getunstoredblock.jl index cf424304..d221286e 100644 --- a/src/blocksparsearrayinterface/getunstoredblock.jl +++ b/src/blocksparsearrayinterface/getunstoredblock.jl @@ -5,44 +5,18 @@ struct GetUnstoredBlock{Axes} axes::Axes end -# Allow customizing based on the block index. -function unstored_block( - A::Type{<:AbstractArray{<:Any,N}}, ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N} -) where {N} - return unstored_block(A, ax) -end -function unstored_block( - A::Type{<:AbstractArray{<:Any,N}}, ax::NTuple{N,AbstractUnitRange{<:Integer}} -) where {N} - a = similar(A, ax) - zero!(a) - return a -end - -using LinearAlgebra: Diagonal -# TODO: This is a hack and is also type-unstable. -function unstored_block( - A::Type{<:Diagonal{<:Any,V}}, ax::NTuple{2,AbstractUnitRange{<:Integer}}, I::Block{2} -) where {V} - if allequal(Tuple(I)) - # Diagonal blocks. - diag = zero!(similar(V, first(ax))) - return Diagonal(diag) - else - # Off-diagonal blocks. - return zero!(similar(similartype(V, typeof(ax)), ax)) +@inline function (f::GetUnstoredBlock)( + ::Type{<:AbstractArray{A,N}}, I::Vararg{Int,N} +) where {A,N} + ax = ntuple(N) do d + return only(axes(f.axes[d][Block(I[d])])) end + return zero!(similar(A, ax)) end - @inline function (f::GetUnstoredBlock)( a::AbstractArray{<:Any,N}, I::Vararg{Int,N} ) where {N} - # TODO: Make sure this works for sparse or block sparse blocks, immutable - # blocks, diagonal blocks, etc.! - b_ax = ntuple(ndims(a)) do d - return only(axes(f.axes[d][Block(I[d])])) - end - return unstored_block(eltype(a), b_ax, Block(I)) + return f(typeof(a), I...) end # TODO: Use `Base.to_indices`. @inline function (f::GetUnstoredBlock)( @@ -50,3 +24,19 @@ end ) where {N} return f(a, Tuple(I)...) end + +# TODO: this is a hack and is also type-unstable +using LinearAlgebra: Diagonal +using TypeParameterAccessors: similartype +function (f::GetUnstoredBlock)( + ::Type{<:AbstractMatrix{<:Diagonal{<:Any,V}}}, I::Vararg{Int,2} +) where {V} + ax = ntuple(2) do d + return only(axes(f.axes[d][Block(I[d])])) + end + if allequal(I) + return Diagonal(zero!(similar(V, first(ax)))) + else + return zero!(similar(similartype(V, typeof(ax)), ax)) + end +end