Skip to content

Add truncation functionality for SVD #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 9, 2025
1 change: 1 addition & 0 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")

# factorizations
include("factorizations/svd.jl")
include("factorizations/truncation.jl")

end
3 changes: 2 additions & 1 deletion src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)

A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped algorithm on
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix.
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
a block permuted block-diagonal matrix.
"""
struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
MatrixAlgebraKit.AbstractAlgorithm
Expand Down
85 changes: 85 additions & 0 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!

"""
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)

A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
block-diagonal matrix.
"""
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
strategy::T
end

const TBlockUSVᴴ = Tuple{
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
}

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
)
# TODO assert blockdiagonal
return MatrixAlgebraKit.truncate!(
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
)
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
)
ind = MatrixAlgebraKit.findtruncated(diagview(S), strategy.strategy)
# cannot use regular slicing here: I want to slice without altering blockstructure
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
indexmask = falses(size(S, 1))
indexmask[ind] .= true

# first determine the block structure of the output to avoid having assumptions on the
# data structures
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
Slengths = filter!(>(0), map(counter, blocks(ax)))
Sax = blockedrange(Slengths)
Ũ = similar(U, axes(U, 1), Sax)
S̃ = similar(S, Sax, Sax)
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))

# then loop over the blocks and assign the data
# TODO: figure out if we can presort and loop over the blocks -
# for now this has issues with missing blocks
bI_Us = collect(eachblockstoredindex(U))
bI_Ss = collect(eachblockstoredindex(S))
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))

I′ = 0 # number of skipped blocks that got fully truncated
for (I, b) in enumerate(blocks(ax))
mask = indexmask[b]

if !any(mask)
I′ += 1
continue
end

bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
"No U-block found for $I"
)
bU = Tuple(bI_Us[bU_id])
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]

bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
"No Vᴴ-block found for $I"
)
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]

bS_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss) error(
"No S-block found for $I"
)
bS = Tuple(bI_Ss[bS_id])
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
end

return Ũ, S̃, Ṽᴴ
end

56 changes: 55 additions & 1 deletion test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
using MatrixAlgebraKit: svd_compact, svd_full
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank
using LinearAlgebra: LinearAlgebra
using Random: Random
using Test: @inferred, @testset, @test
Expand Down Expand Up @@ -83,3 +83,57 @@ end
usv = svd_full(c)
@test test_svd(c, usv; full=true)
end

# svd_trunc!
# ----------

@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
(m, n), T = first(test_params)
a = BlockSparseArray{T}(undef, m, n)

# test blockdiagonal
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end

minmn = min(size(a)...)
r = max(1, minmn - 2)

U1, S1, V1ᴴ = svd_trunc(a; trunc=truncrank(r))
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc=truncrank(r))
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ

@test (U1' * U1 LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)

# test permuted blockdiagonal
perm = Random.randperm(length(m))
b = a[Block.(perm), Block.(1:length(n))]
U1, S1, V1ᴴ = svd_trunc(b; trunc=truncrank(r))
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc=truncrank(r))
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ

@test (U1' * U1 LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
U1, S1, V1ᴴ = svd_trunc(c; trunc=truncrank(r))
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc=truncrank(r))
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ

@test (U1' * U1 LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
end
Loading