Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.6.9"
version = "0.6.10"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 3 additions & 1 deletion src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,11 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
end

function SparseArraysBase.eachstoredindex(::IndexCartesian, a::SparseSubArrayBlocks)
return filter(eachindex(a)) do I
isempty(a) && return CartesianIndex{ndims(a)}[]
inds = filter(eachindex(a)) do I
return isstored(a, I)
end
return inds

## # TODO: This only works for blockwise slices, i.e. slices using
## # `BlockSliceCollection`.
Expand Down
62 changes: 2 additions & 60 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,69 +45,11 @@ function MatrixAlgebraKit.findtruncated(
return indexmask
end

function similar_truncate(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
)
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
s_lengths = filter!(>(0), map(counter, blocks(ax)))
s_axis = blockedrange(s_lengths)
Ũ = similar(U, axes(U, 1), s_axis)
S̃ = similar(S, s_axis, s_axis)
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
return Ũ, S̃, Ṽᴴ
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
)
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy)

# first determine the block structure of the output to avoid having assumptions on the
# data structures
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)

# then loop over the blocks and assign the data
# TODO: figure out if we can presort and loop over the blocks -
# for now this has issues with missing blocks
bI_Us = collect(eachblockstoredindex(U))
bI_Ss = collect(eachblockstoredindex(S))
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))

I′ = 0 # number of skipped blocks that got fully truncated
ax = axes(S, 1)
for I in 1:blocksize(ax, 1)
b = ax[Block(I)]
mask = indexmask[b]

if !any(mask)
I′ += 1
continue
end

bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
"No U-block found for $I"
)
bU = Tuple(bI_Us[bU_id])
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]

bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
"No Vᴴ-block found for $I"
)
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]

bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss)
if !isnothing(bS_id)
bS = Tuple(bI_Ss[bS_id])
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
end
end

return Ũ, S̃, Ṽᴴ
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
return (U[:, I], S[I, I], Vᴴ[I, :])
end
Loading