@@ -3,20 +3,74 @@ const AbstractBlockSparseMatrix{T} = AbstractBlockSparseArray{T,2}
33# SVD is implemented by trying to
44# 1. Attempt to find a block-diagonal implementation by permuting
55# 2. Fallback to AbstractBlockArray implementation via BlockedArray
6- function svd (
7- A:: AbstractBlockSparseMatrix ; full:: Bool = false , alg:: Algorithm = default_svd_alg (A)
8- )
9- T = LinearAlgebra. eigtype (eltype (A))
10- A′_and_blockperms = try_to_blockdiagonal (A)
116
12- if isnothing (A′_and_blockperms)
13- # not block-diagonal, fall back to dense case
14- Adense = eigencopy_oftype (A, T)
15- return svd! (Adense; full, alg)
7+ function eigencopy_oftype (A:: AbstractBlockSparseMatrix , T)
8+ if is_block_permutation_matrix (A)
9+ Acopy = similar (A, T)
10+ for bI in eachblockstoredindex (A)
11+ Acopy[bI] = eigencopy_oftype (A[bI], T)
12+ end
13+ return Acopy
14+ else
15+ return BlockedMatrix {T} (A)
16+ end
17+ end
18+
19+ function is_block_permutation_matrix (a:: AbstractBlockSparseMatrix )
20+ return allunique (first ∘ Tuple, eachblockstoredindex (a)) &&
21+ allunique (last ∘ Tuple, eachblockstoredindex (a))
22+ end
23+
24+ function _allocate_svd_output (A:: AbstractBlockSparseMatrix , full:: Bool , :: Algorithm )
25+ @assert ! full " TODO"
26+ bm, bn = blocksize (A)
27+ bmn = min (bm, bn)
28+
29+ brows = blocklengths (axes (A, 1 ))
30+ bcols = blocklengths (axes (A, 2 ))
31+ slengths = Vector {Int} (undef, bmn)
32+
33+ # fill in values for blocks that are present
34+ bIs = collect (eachblockstoredindex (A))
35+ browIs = Int .(first .(Tuple .(bIs)))
36+ bcolIs = Int .(last .(Tuple .(bIs)))
37+ for bI in eachblockstoredindex (A)
38+ row, col = Int .(Tuple (bI))
39+ nrows = brows[row]
40+ ncols = bcols[col]
41+ slengths[col] = min (nrows, ncols)
1642 end
1743
18- # compute block-by-block and permute back
19- A″, (I, J) = A′
20- F = svd! (eigencopy_oftype (A″, T); full, alg)
21- return SVD (F. U[I, J], F. S, F. Vt)
44+ # fill in values for blocks that aren't present, pairing them in order of occurence
45+ # this is a convention, which at least gives the expected results for blockdiagonal
46+ emptyrows = findall (∉ (browIs), 1 : bmn)
47+ emptycols = findall (∉ (bcolIs), 1 : bmn)
48+ for (row, col) in zip (emptyrows, emptycols)
49+ slengths[col] = min (brows[row], bcols[col])
50+ end
51+
52+ U = similar (A, axes (A, 1 ), blockedrange (slengths))
53+ S = similar (A, real (eltype (A)), blockedrange (slengths))
54+ Vt = similar (A, blockedrange (slengths), axes (A, 2 ))
55+
56+ return U, S, Vt
57+ end
58+
59+ function svd (A:: AbstractBlockSparseMatrix ; kwargs... )
60+ return svd! (eigencopy_oftype (A, LinearAlgebra. eigtype (eltype (A))); kwargs... )
61+ end
62+
63+ function svd! (
64+ A:: AbstractBlockSparseMatrix ; full:: Bool = false , alg:: Algorithm = default_svd_alg (A)
65+ )
66+ @assert is_block_permutation_matrix (A) " Cannot keep sparsity: use `svd` to convert to `BlockedMatrix"
67+ U, S, Vt = _allocate_svd_output (A, full, alg)
68+ for bI in eachblockstoredindex (A)
69+ bUSV = svd! (A[bI]; full, alg)
70+ brow, bcol = Int .(Tuple (bI))
71+ U[Block (brow, bcol)] = bUSV. U
72+ S[Block (bcol)] = bUSV. S
73+ Vt[Block (bcol, bcol)] = bUSV. Vt
74+ end
75+ return SVD (U, S, Vt)
2276end
0 commit comments