diff --git a/Project.toml b/Project.toml index 72d74f5..720a92b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.12" +version = "0.7.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index ec996e0..4229343 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -45,6 +45,7 @@ include("blocksparsearray/blockdiagonalarray.jl") include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") # factorizations +include("factorizations/tensorproducts.jl") include("factorizations/svd.jl") include("factorizations/truncation.jl") include("factorizations/qr.jl") diff --git a/src/blocksparsearrayinterface/getunstoredblock.jl b/src/blocksparsearrayinterface/getunstoredblock.jl index 36cadde..d221286 100644 --- a/src/blocksparsearrayinterface/getunstoredblock.jl +++ b/src/blocksparsearrayinterface/getunstoredblock.jl @@ -6,16 +6,17 @@ struct GetUnstoredBlock{Axes} end @inline function (f::GetUnstoredBlock)( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - # TODO: Make sure this works for sparse or block sparse blocks, immutable - # blocks, diagonal blocks, etc.! - b_ax = ntuple(ndims(a)) do d + ::Type{<:AbstractArray{A,N}}, I::Vararg{Int,N} +) where {A,N} + ax = ntuple(N) do d return only(axes(f.axes[d][Block(I[d])])) end - b = similar(eltype(a), b_ax) - zero!(b) - return b + return zero!(similar(A, ax)) +end +@inline function (f::GetUnstoredBlock)( + a::AbstractArray{<:Any,N}, I::Vararg{Int,N} +) where {N} + return f(typeof(a), I...) end # TODO: Use `Base.to_indices`. @inline function (f::GetUnstoredBlock)( @@ -25,16 +26,17 @@ end end # TODO: this is a hack and is also type-unstable +using LinearAlgebra: Diagonal +using TypeParameterAccessors: similartype function (f::GetUnstoredBlock)( - a::AbstractMatrix{LinearAlgebra.Diagonal{T,V}}, I::Vararg{Int,2} -) where {T,V} - b_size = ntuple(ndims(a)) do d - return length(f.axes[d][Block(I[d])]) + ::Type{<:AbstractMatrix{<:Diagonal{<:Any,V}}}, I::Vararg{Int,2} +) where {V} + ax = ntuple(2) do d + return only(axes(f.axes[d][Block(I[d])])) end - if I[1] == I[2] - diag = zero!(similar(V, b_size[1])) - return LinearAlgebra.Diagonal{T,V}(diag) + if allequal(I) + return Diagonal(zero!(similar(V, first(ax)))) else - return zeros(T, b_size...) + return zero!(similar(similartype(V, typeof(ax)), ax)) end end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index be7ca75..b57b8eb 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -25,10 +25,11 @@ end function similar_output( ::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm ) - U = similar(A, axes(A, 1), S_axes[1]) - S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes) - Vt = similar(A, S_axes[2], axes(A, 2)) - return U, S, Vt + BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg))) + U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1])) + S = similar(A, BlockType(BS), S_axes) + Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2))) + return U, S, Vᴴ end function MatrixAlgebraKit.initialize_output( @@ -48,9 +49,8 @@ function MatrixAlgebraKit.initialize_output( bcolIs = Int.(last.(Tuple.(bIs))) for bI in eachblockstoredindex(A) row, col = Int.(Tuple(bI)) - len = minimum(length, (brows[row], bcols[col])) - u_axes[col] = brows[row][Base.OneTo(len)] - v_axes[col] = bcols[col][Base.OneTo(len)] + u_axes[col] = infimum(brows[row], bcols[col]) + v_axes[col] = infimum(bcols[col], brows[row]) end # fill in values for blocks that aren't present, pairing them in order of occurence @@ -58,9 +58,8 @@ function MatrixAlgebraKit.initialize_output( emptyrows = setdiff(1:bm, browIs) emptycols = setdiff(1:bn, bcolIs) for (row, col) in zip(emptyrows, emptycols) - len = minimum(length, (brows[row], bcols[col])) - u_axes[col] = brows[row][Base.OneTo(len)] - v_axes[col] = bcols[col][Base.OneTo(len)] + u_axes[col] = infimum(brows[row], bcols[col]) + v_axes[col] = infimum(bcols[col], brows[row]) end u_axis = mortar_axis(u_axes) diff --git a/src/factorizations/tensorproducts.jl b/src/factorizations/tensorproducts.jl new file mode 100644 index 0000000..bbf02ae --- /dev/null +++ b/src/factorizations/tensorproducts.jl @@ -0,0 +1,19 @@ +function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("infimum only defined for ranges starting at 1")) + if length(r1) ≤ length(r2) + return r1 + else + return r1[r2] + end +end + +function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("supremum only defined for ranges starting at 1")) + if length(r1) ≥ length(r2) + return r1 + else + return r2 + end +end