Skip to content

Commit 1cf4cb9

Browse files
committed
Rewrite svd to work with rectangular matrices
1 parent 964a8be commit 1cf4cb9

File tree

2 files changed

+69
-44
lines changed

2 files changed

+69
-44
lines changed

src/abstractblocksparsearray/abstractblocksparsematrix.jl

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,74 @@ const AbstractBlockSparseMatrix{T} = AbstractBlockSparseArray{T,2}
33
# SVD is implemented by trying to
44
# 1. Attempt to find a block-diagonal implementation by permuting
55
# 2. Fallback to AbstractBlockArray implementation via BlockedArray
6-
function svd(
7-
A::AbstractBlockSparseMatrix; full::Bool=false, alg::Algorithm=default_svd_alg(A)
8-
)
9-
T = LinearAlgebra.eigtype(eltype(A))
10-
A′_and_blockperms = try_to_blockdiagonal(A)
116

12-
if isnothing(A′_and_blockperms)
13-
# not block-diagonal, fall back to dense case
14-
Adense = eigencopy_oftype(A, T)
15-
return svd!(Adense; full, alg)
7+
function eigencopy_oftype(A::AbstractBlockSparseMatrix, T)
8+
if is_block_permutation_matrix(A)
9+
Acopy = similar(A, T)
10+
for bI in eachblockstoredindex(A)
11+
Acopy[bI] = eigencopy_oftype(A[bI], T)
12+
end
13+
return Acopy
14+
else
15+
return BlockedMatrix{T}(A)
16+
end
17+
end
18+
19+
function is_block_permutation_matrix(a::AbstractBlockSparseMatrix)
20+
return allunique(first Tuple, eachblockstoredindex(a)) &&
21+
allunique(last Tuple, eachblockstoredindex(a))
22+
end
23+
24+
function _allocate_svd_output(A::AbstractBlockSparseMatrix, full::Bool, ::Algorithm)
25+
@assert !full "TODO"
26+
bm, bn = blocksize(A)
27+
bmn = min(bm, bn)
28+
29+
brows = blocklengths(axes(A, 1))
30+
bcols = blocklengths(axes(A, 2))
31+
slengths = Vector{Int}(undef, bmn)
32+
33+
# fill in values for blocks that are present
34+
bIs = collect(eachblockstoredindex(A))
35+
browIs = Int.(first.(Tuple.(bIs)))
36+
bcolIs = Int.(last.(Tuple.(bIs)))
37+
for bI in eachblockstoredindex(A)
38+
row, col = Int.(Tuple(bI))
39+
nrows = brows[row]
40+
ncols = bcols[col]
41+
slengths[col] = min(nrows, ncols)
1642
end
1743

18-
# compute block-by-block and permute back
19-
A″, (I, J) = A′
20-
F = svd!(eigencopy_oftype(A″, T); full, alg)
21-
return SVD(F.U[I, J], F.S, F.Vt)
44+
# fill in values for blocks that aren't present, pairing them in order of occurence
45+
# this is a convention, which at least gives the expected results for blockdiagonal
46+
emptyrows = findall((browIs), 1:bmn)
47+
emptycols = findall((bcolIs), 1:bmn)
48+
for (row, col) in zip(emptyrows, emptycols)
49+
slengths[col] = min(brows[row], bcols[col])
50+
end
51+
52+
U = similar(A, axes(A, 1), blockedrange(slengths))
53+
S = similar(A, real(eltype(A)), blockedrange(slengths))
54+
Vt = similar(A, blockedrange(slengths), axes(A, 2))
55+
56+
return U, S, Vt
57+
end
58+
59+
function svd(A::AbstractBlockSparseMatrix; kwargs...)
60+
return svd!(eigencopy_oftype(A, LinearAlgebra.eigtype(eltype(A))); kwargs...)
61+
end
62+
63+
function svd!(
64+
A::AbstractBlockSparseMatrix; full::Bool=false, alg::Algorithm=default_svd_alg(A)
65+
)
66+
@assert is_block_permutation_matrix(A) "Cannot keep sparsity: use `svd` to convert to `BlockedMatrix"
67+
U, S, Vt = _allocate_svd_output(A, full, alg)
68+
for bI in eachblockstoredindex(A)
69+
bUSV = svd!(A[bI]; full, alg)
70+
brow, bcol = Int.(Tuple(bI))
71+
U[Block(brow, bcol)] = bUSV.U
72+
S[Block(bcol)] = bUSV.S
73+
Vt[Block(bcol, bcol)] = bUSV.Vt
74+
end
75+
return SVD(U, S, Vt)
2276
end

src/blocksparsearray/blockdiagonalarray.jl

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,9 @@ function BlockDiagonal(blocks::AbstractVector{<:AbstractMatrix})
1111
)
1212
end
1313

14-
# Cast to block-diagonal implementation if permuted-blockdiagonal
15-
function try_to_blockdiagonal_perm(A)
16-
inds = map(x -> Int.(Tuple(x)), vec(collect(block_stored_indices(A))))
17-
I = first.(inds)
18-
allunique(I) || return nothing
19-
J = last.(inds)
20-
p = sortperm(J)
21-
Jsorted = J[p]
22-
allunique(Jsorted) || return nothing
23-
return Block.(I[p], Jsorted)
24-
end
25-
26-
"""
27-
try_to_blockdiagonal(A)
28-
29-
Attempt to find a permutation of blocks that makes `A` blockdiagonal. If unsuccesful,
30-
returns nothing, otherwise returns both the blockdiagonal `B` as well as the permutation `I, J`.
31-
"""
32-
function try_to_blockdiagonal(A::AbstractBlockSparseMatrix)
33-
perm = try_to_blockdiagonal_perm(A)
34-
isnothing(perm) && return perm
35-
I = first.(Tuple.(perm))
36-
J = last.(Tuple.(perm))
37-
diagblocks = map(invperm(I), J) do i, j
38-
return A[Block(i, j)]
39-
end
40-
return BlockDiagonal(diagblocks), perm
41-
end
42-
4314
# SVD implementation
44-
function eigencopy_oftype(A::BlockDiagonal, S)
45-
diag = map(Base.Fix2(eigencopy_oftype, S), A.blocks.diag)
15+
function eigencopy_oftype(A::BlockDiagonal, T)
16+
diag = map(Base.Fix2(eigencopy_oftype, T), A.blocks.diag)
4617
return BlockDiagonal(diag)
4718
end
4819

0 commit comments

Comments
 (0)