Skip to content

Commit f42daa9

Browse files
authored
Simplified truncated SVD with logical indexing (#132)
1 parent 7d3f1bf commit f42daa9

File tree

6 files changed

+9
-65
lines changed

6 files changed

+9
-65
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.9"
4+
version = "0.7.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
66

77
[compat]
88
BlockArrays = "1"
9-
BlockSparseArrays = "0.6"
9+
BlockSparseArrays = "0.7"
1010
Documenter = "1"
1111
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
BlockArrays = "1"
8-
BlockSparseArrays = "0.6"
8+
BlockSparseArrays = "0.7"
99
Test = "1"

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,11 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
432432
end
433433

434434
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::SparseSubArrayBlocks)
435-
return filter(eachindex(a)) do I
435+
isempty(a) && return CartesianIndex{ndims(a)}[]
436+
inds = filter(eachindex(a)) do I
436437
return isstored(a, I)
437438
end
439+
return inds
438440

439441
## # TODO: This only works for blockwise slices, i.e. slices using
440442
## # `BlockSliceCollection`.

src/factorizations/truncation.jl

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -45,69 +45,11 @@ function MatrixAlgebraKit.findtruncated(
4545
return indexmask
4646
end
4747

48-
function similar_truncate(
49-
::typeof(svd_trunc!),
50-
(U, S, Vᴴ)::TBlockUSVᴴ,
51-
strategy::BlockPermutedDiagonalTruncationStrategy,
52-
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
53-
)
54-
ax = axes(S, 1)
55-
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
56-
s_lengths = filter!(>(0), map(counter, blocks(ax)))
57-
s_axis = blockedrange(s_lengths)
58-
= similar(U, axes(U, 1), s_axis)
59-
= similar(S, s_axis, s_axis)
60-
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
61-
return Ũ, S̃, Ṽᴴ
62-
end
63-
6448
function MatrixAlgebraKit.truncate!(
6549
::typeof(svd_trunc!),
6650
(U, S, Vᴴ)::TBlockUSVᴴ,
6751
strategy::BlockPermutedDiagonalTruncationStrategy,
6852
)
69-
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
70-
71-
# first determine the block structure of the output to avoid having assumptions on the
72-
# data structures
73-
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
74-
75-
# then loop over the blocks and assign the data
76-
# TODO: figure out if we can presort and loop over the blocks -
77-
# for now this has issues with missing blocks
78-
bI_Us = collect(eachblockstoredindex(U))
79-
bI_Ss = collect(eachblockstoredindex(S))
80-
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
81-
82-
I′ = 0 # number of skipped blocks that got fully truncated
83-
ax = axes(S, 1)
84-
for I in 1:blocksize(ax, 1)
85-
b = ax[Block(I)]
86-
mask = indexmask[b]
87-
88-
if !any(mask)
89-
I′ += 1
90-
continue
91-
end
92-
93-
bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
94-
"No U-block found for $I"
95-
)
96-
bU = Tuple(bI_Us[bU_id])
97-
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]
98-
99-
bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
100-
"No Vᴴ-block found for $I"
101-
)
102-
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
103-
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]
104-
105-
bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss)
106-
if !isnothing(bS_id)
107-
bS = Tuple(bI_Ss[bS_id])
108-
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
109-
end
110-
end
111-
112-
return Ũ, S̃, Ṽᴴ
53+
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
54+
return (U[:, I], S[I, I], Vᴴ[I, :])
11355
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Adapt = "4"
2323
Aqua = "0.8"
2424
ArrayLayouts = "1"
2525
BlockArrays = "1"
26-
BlockSparseArrays = "0.6"
26+
BlockSparseArrays = "0.7"
2727
DiagonalArrays = "0.3"
2828
GPUArraysCore = "0.2"
2929
JLArrays = "0.2"

0 commit comments

Comments
 (0)