@@ -3,20 +3,74 @@ const AbstractBlockSparseMatrix{T} = AbstractBlockSparseArray{T,2}
3
3
# SVD is implemented by trying to
4
4
# 1. Attempt to find a block-diagonal implementation by permuting
5
5
# 2. Fallback to AbstractBlockArray implementation via BlockedArray
6
- function svd (
7
- A:: AbstractBlockSparseMatrix ; full:: Bool = false , alg:: Algorithm = default_svd_alg (A)
8
- )
9
- T = LinearAlgebra. eigtype (eltype (A))
10
- A′_and_blockperms = try_to_blockdiagonal (A)
11
6
12
- if isnothing (A′_and_blockperms)
13
- # not block-diagonal, fall back to dense case
14
- Adense = eigencopy_oftype (A, T)
15
- return svd! (Adense; full, alg)
7
+ function eigencopy_oftype (A:: AbstractBlockSparseMatrix , T)
8
+ if is_block_permutation_matrix (A)
9
+ Acopy = similar (A, T)
10
+ for bI in eachblockstoredindex (A)
11
+ Acopy[bI] = eigencopy_oftype (A[bI], T)
12
+ end
13
+ return Acopy
14
+ else
15
+ return BlockedMatrix {T} (A)
16
+ end
17
+ end
18
+
19
+ function is_block_permutation_matrix (a:: AbstractBlockSparseMatrix )
20
+ return allunique (first ∘ Tuple, eachblockstoredindex (a)) &&
21
+ allunique (last ∘ Tuple, eachblockstoredindex (a))
22
+ end
23
+
24
+ function _allocate_svd_output (A:: AbstractBlockSparseMatrix , full:: Bool , :: Algorithm )
25
+ @assert ! full " TODO"
26
+ bm, bn = blocksize (A)
27
+ bmn = min (bm, bn)
28
+
29
+ brows = blocklengths (axes (A, 1 ))
30
+ bcols = blocklengths (axes (A, 2 ))
31
+ slengths = Vector {Int} (undef, bmn)
32
+
33
+ # fill in values for blocks that are present
34
+ bIs = collect (eachblockstoredindex (A))
35
+ browIs = Int .(first .(Tuple .(bIs)))
36
+ bcolIs = Int .(last .(Tuple .(bIs)))
37
+ for bI in eachblockstoredindex (A)
38
+ row, col = Int .(Tuple (bI))
39
+ nrows = brows[row]
40
+ ncols = bcols[col]
41
+ slengths[col] = min (nrows, ncols)
16
42
end
17
43
18
- # compute block-by-block and permute back
19
- A″, (I, J) = A′
20
- F = svd! (eigencopy_oftype (A″, T); full, alg)
21
- return SVD (F. U[I, J], F. S, F. Vt)
44
+ # fill in values for blocks that aren't present, pairing them in order of occurence
45
+ # this is a convention, which at least gives the expected results for blockdiagonal
46
+ emptyrows = findall (∉ (browIs), 1 : bmn)
47
+ emptycols = findall (∉ (bcolIs), 1 : bmn)
48
+ for (row, col) in zip (emptyrows, emptycols)
49
+ slengths[col] = min (brows[row], bcols[col])
50
+ end
51
+
52
+ U = similar (A, axes (A, 1 ), blockedrange (slengths))
53
+ S = similar (A, real (eltype (A)), blockedrange (slengths))
54
+ Vt = similar (A, blockedrange (slengths), axes (A, 2 ))
55
+
56
+ return U, S, Vt
57
+ end
58
+
59
+ function svd (A:: AbstractBlockSparseMatrix ; kwargs... )
60
+ return svd! (eigencopy_oftype (A, LinearAlgebra. eigtype (eltype (A))); kwargs... )
61
+ end
62
+
63
+ function svd! (
64
+ A:: AbstractBlockSparseMatrix ; full:: Bool = false , alg:: Algorithm = default_svd_alg (A)
65
+ )
66
+ @assert is_block_permutation_matrix (A) " Cannot keep sparsity: use `svd` to convert to `BlockedMatrix"
67
+ U, S, Vt = _allocate_svd_output (A, full, alg)
68
+ for bI in eachblockstoredindex (A)
69
+ bUSV = svd! (A[bI]; full, alg)
70
+ brow, bcol = Int .(Tuple (bI))
71
+ U[Block (brow, bcol)] = bUSV. U
72
+ S[Block (bcol)] = bUSV. S
73
+ Vt[Block (bcol, bcol)] = bUSV. Vt
74
+ end
75
+ return SVD (U, S, Vt)
22
76
end
0 commit comments