diff --git a/Project.toml b/Project.toml index d5200c3b..4cf9329a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.5.3" +version = "0.5.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index b6edef21..f685b761 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -9,11 +9,32 @@ using BlockArrays: BlockSlice, BlockVector, block, + blockedrange, blockindex, + blocklengths, findblock, findblockindex, mortar +# Get the axes of each block of a block array. +function eachblockaxes(a::AbstractArray) + return map(axes, blocks(a)) +end + +axis(a::AbstractVector) = axes(a, 1) + +# Get the axis of each block of a blocked unit +# range. +function eachblockaxis(a::AbstractVector) + return map(axis, blocks(a)) +end + +# Take a collection of axes and mortar them +# into a single blocked axis. +function mortar_axis(axs) + return blockedrange(length.(axs)) +end + # Custom `BlockedUnitRange` constructor that takes a unit range # and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`. function blockedunitrange(a::AbstractUnitRange, blocklengths) diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index f83c11b3..04d187b2 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -21,11 +21,19 @@ function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kw return BlockPermutedDiagonalAlgorithm(alg) end -# TODO: this should be replaced with a more general similar function that can handle setting -# the blocktype and element type - something like S = similar(A, BlockType(...)) -function _similar_S(A::AbstractBlockSparseMatrix, s_axis) +function similar_output( + ::typeof(svd_compact!), + A, + s_axis::AbstractUnitRange, + alg::MatrixAlgebraKit.AbstractAlgorithm, +) + U = similar(A, axes(A, 1), s_axis) T = real(eltype(A)) - return BlockSparseArray{T,2,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis)) + # TODO: this should be replaced with a more general similar function that can handle setting + # the blocktype and element type - something like S = similar(A, BlockType(...)) + S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis)) + Vt = similar(A, s_axis, axes(A, 2)) + return U, S, Vt end function MatrixAlgebraKit.initialize_output( @@ -34,9 +42,9 @@ function MatrixAlgebraKit.initialize_output( bm, bn = blocksize(A) bmn = min(bm, bn) - brows = blocklengths(axes(A, 1)) - bcols = blocklengths(axes(A, 2)) - slengths = Vector{Int}(undef, bmn) + brows = eachblockaxis(axes(A, 1)) + bcols = eachblockaxis(axes(A, 2)) + s_axes = similar(brows, bmn) # fill in values for blocks that are present bIs = collect(eachblockstoredindex(A)) @@ -44,9 +52,7 @@ function MatrixAlgebraKit.initialize_output( bcolIs = Int.(last.(Tuple.(bIs))) for bI in eachblockstoredindex(A) row, col = Int.(Tuple(bI)) - nrows = brows[row] - ncols = bcols[col] - slengths[col] = min(nrows, ncols) + s_axes[col] = argmin(length, (brows[row], bcols[col])) end # fill in values for blocks that aren't present, pairing them in order of occurence @@ -54,13 +60,11 @@ function MatrixAlgebraKit.initialize_output( emptyrows = setdiff(1:bm, browIs) emptycols = setdiff(1:bn, bcolIs) for (row, col) in zip(emptyrows, emptycols) - slengths[col] = min(brows[row], bcols[col]) + s_axes[col] = argmin(length, (brows[row], bcols[col])) end - s_axis = blockedrange(slengths) - U = similar(A, axes(A, 1), s_axis) - S = _similar_S(A, s_axis) - Vt = similar(A, s_axis, axes(A, 2)) + s_axis = mortar_axis(s_axes) + U, S, Vt = similar_output(svd_compact!, A, s_axis, alg) # allocate output for bI in eachblockstoredindex(A) @@ -79,13 +83,23 @@ function MatrixAlgebraKit.initialize_output( return U, S, Vt end +function similar_output( + ::typeof(svd_full!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.AbstractAlgorithm +) + U = similar(A, axes(A, 1), s_axis) + T = real(eltype(A)) + S = similar(A, T, (s_axis, axes(A, 2))) + Vt = similar(A, axes(A, 2), axes(A, 2)) + return U, S, Vt +end + function MatrixAlgebraKit.initialize_output( ::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm ) bm, bn = blocksize(A) - brows = blocklengths(axes(A, 1)) - slengths = copy(brows) + brows = eachblockaxis(axes(A, 1)) + s_axes = similar(brows) # fill in values for blocks that are present bIs = collect(eachblockstoredindex(A)) @@ -93,8 +107,7 @@ function MatrixAlgebraKit.initialize_output( bcolIs = Int.(last.(Tuple.(bIs))) for bI in eachblockstoredindex(A) row, col = Int.(Tuple(bI)) - nrows = brows[row] - slengths[col] = nrows + s_axes[col] = brows[row] end # fill in values for blocks that aren't present, pairing them in order of occurence @@ -102,17 +115,14 @@ function MatrixAlgebraKit.initialize_output( emptyrows = setdiff(1:bm, browIs) emptycols = setdiff(1:bn, bcolIs) for (row, col) in zip(emptyrows, emptycols) - slengths[col] = brows[row] + s_axes[col] = brows[row] end for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows)) - slengths[bn + i] = brows[emptyrows[k]] + s_axes[bn + i] = brows[emptyrows[k]] end - s_axis = blockedrange(slengths) - U = similar(A, axes(A, 1), s_axis) - Tr = real(eltype(A)) - S = similar(A, Tr, (s_axis, axes(A, 2))) - Vt = similar(A, axes(A, 2), axes(A, 2)) + s_axis = mortar_axis(s_axes) + U, S, Vt = similar_output(svd_full!, A, s_axis, alg) # allocate output for bI in eachblockstoredindex(A) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 0029c7a1..e3362128 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated( return indexmask end +function similar_truncate( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::TBlockUSVᴴ, + strategy::BlockPermutedDiagonalTruncationStrategy, + indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy), +) + ax = axes(S, 1) + counter = Base.Fix1(count, Base.Fix1(getindex, indexmask)) + s_lengths = filter!(>(0), map(counter, blocks(ax))) + s_axis = blockedrange(s_lengths) + Ũ = similar(U, axes(U, 1), s_axis) + S̃ = similar(S, s_axis, s_axis) + Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2)) + return Ũ, S̃, Ṽᴴ +end + function MatrixAlgebraKit.truncate!( ::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, @@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!( # first determine the block structure of the output to avoid having assumptions on the # data structures - ax = axes(S, 1) - counter = Base.Fix1(count, Base.Fix1(getindex, indexmask)) - Slengths = filter!(>(0), map(counter, blocks(ax))) - Sax = blockedrange(Slengths) - Ũ = similar(U, axes(U, 1), Sax) - S̃ = similar(S, Sax, Sax) - Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2)) + Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask) # then loop over the blocks and assign the data # TODO: figure out if we can presort and loop over the blocks - @@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!( bI_Vᴴs = collect(eachblockstoredindex(Vᴴ)) I′ = 0 # number of skipped blocks that got fully truncated + ax = axes(S, 1) for I in 1:blocksize(ax, 1) b = ax[Block(I)] mask = indexmask[b]