Skip to content

Commit 4d98d75

Browse files
committed
refactor svd to work with new algorithm types
1 parent 7ef0aff commit 4d98d75

File tree

1 file changed

+111
-140
lines changed

1 file changed

+111
-140
lines changed

src/factorizations/svd.jl

Lines changed: 111 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -29,59 +29,34 @@ function similar_output(
2929
end
3030

3131
function MatrixAlgebraKit.initialize_output(
32-
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
32+
::typeof(svd_compact!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
33+
)
34+
return nothing
35+
end
36+
function MatrixAlgebraKit.initialize_output(
37+
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
3338
)
34-
bm, bn = blocksize(A)
35-
bmn = min(bm, bn)
36-
3739
brows = eachblockaxis(axes(A, 1))
3840
bcols = eachblockaxis(axes(A, 2))
39-
u_axes = similar(brows, bmn)
40-
v_axes = similar(brows, bmn)
41+
# using the property that zip stops as soon as one of the iterators is exhausted
42+
s_axes = map(splat(infimum), zip(brows, bcols))
43+
s_axis = mortar_axis(s_axes)
44+
S_axes = (s_axis, s_axis)
45+
U, S, Vᴴ = similar_output(svd_compact!, A, S_axes, alg)
4146

42-
# fill in values for blocks that are present
43-
bIs = collect(eachblockstoredindex(A))
44-
browIs = Int.(first.(Tuple.(bIs)))
45-
bcolIs = Int.(last.(Tuple.(bIs)))
4647
for bI in eachblockstoredindex(A)
47-
row, col = Int.(Tuple(bI))
48-
u_axes[col] = infimum(brows[row], bcols[col])
49-
v_axes[col] = infimum(bcols[col], brows[row])
50-
end
51-
52-
# fill in values for blocks that aren't present, pairing them in order of occurence
53-
# this is a convention, which at least gives the expected results for blockdiagonal
54-
emptyrows = setdiff(1:bm, browIs)
55-
emptycols = setdiff(1:bn, bcolIs)
56-
for (row, col) in zip(emptyrows, emptycols)
57-
u_axes[col] = infimum(brows[row], bcols[col])
58-
v_axes[col] = infimum(bcols[col], brows[row])
59-
end
60-
61-
u_axis = mortar_axis(u_axes)
62-
v_axis = mortar_axis(v_axes)
63-
S_axes = (u_axis, v_axis)
64-
U, S, Vt = similar_output(svd_compact!, A, S_axes, alg)
65-
66-
# allocate output
67-
for bI in eachblockstoredindex(A)
68-
brow, bcol = Tuple(bI)
6948
block = @view!(A[bI])
7049
block_alg = block_algorithm(alg, block)
71-
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
50+
I = first(Tuple(bI)) # == last(Tuple(bI))
51+
U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit.initialize_output(
7252
svd_compact!, block, block_alg
7353
)
7454
end
7555

76-
# allocate output for blocks that aren't present -- do we also fill identities here?
77-
for (row, col) in zip(emptyrows, emptycols)
78-
@view!(U[Block(row, col)])
79-
@view!(Vt[Block(col, col)])
80-
end
81-
82-
return U, S, Vt
56+
return U, S, Vᴴ
8357
end
8458

59+
8560
function similar_output(
8661
::typeof(svd_full!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
8762
)
@@ -93,65 +68,39 @@ function similar_output(
9368
end
9469

9570
function MatrixAlgebraKit.initialize_output(
96-
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
71+
::typeof(svd_full!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
9772
)
98-
bm, bn = blocksize(A)
99-
100-
brows = eachblockaxis(axes(A, 1))
101-
u_axes = similar(brows)
102-
103-
# fill in values for blocks that are present
104-
bIs = collect(eachblockstoredindex(A))
105-
browIs = Int.(first.(Tuple.(bIs)))
106-
bcolIs = Int.(last.(Tuple.(bIs)))
107-
for bI in eachblockstoredindex(A)
108-
row, col = Int.(Tuple(bI))
109-
u_axes[col] = brows[row]
110-
end
111-
112-
# fill in values for blocks that aren't present, pairing them in order of occurence
113-
# this is a convention, which at least gives the expected results for blockdiagonal
114-
emptyrows = setdiff(1:bm, browIs)
115-
emptycols = setdiff(1:bn, bcolIs)
116-
for (row, col) in zip(emptyrows, emptycols)
117-
u_axes[col] = brows[row]
118-
end
119-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
120-
u_axes[bn + i] = brows[emptyrows[k]]
121-
end
73+
return nothing
74+
end
12275

123-
u_axis = mortar_axis(u_axes)
124-
S_axes = (u_axis, axes(A, 2))
125-
U, S, Vt = similar_output(svd_full!, A, S_axes, alg)
76+
function MatrixAlgebraKit.initialize_output(
77+
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
78+
)
79+
U, S, Vᴴ = similar_output(svd_full!, A, axes(A), alg)
12680

127-
# allocate output
12881
for bI in eachblockstoredindex(A)
129-
brow, bcol = Tuple(bI)
13082
block = @view!(A[bI])
13183
block_alg = block_algorithm(alg, block)
132-
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
84+
I = first(Tuple(bI)) # == last(Tuple(bI))
85+
U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit.initialize_output(
13386
svd_full!, block, block_alg
13487
)
13588
end
13689

137-
# allocate output for blocks that aren't present -- do we also fill identities here?
138-
for (row, col) in zip(emptyrows, emptycols)
139-
@view!(U[Block(row, col)])
140-
@view!(Vt[Block(col, col)])
141-
end
142-
# also handle extra rows/cols
143-
for i in (length(emptyrows) + 1):length(emptycols)
144-
@view!(Vt[Block(emptycols[i], emptycols[i])])
145-
end
146-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
147-
@view!(U[Block(emptyrows[k], bn + i)])
148-
end
90+
return U, S, Vᴴ
91+
end
14992

150-
return U, S, Vt
93+
function MatrixAlgebraKit.check_input(
94+
::typeof(svd_compact!),
95+
A::AbstractBlockSparseMatrix,
96+
USVᴴ,
97+
::BlockPermutedDiagonalAlgorithm,
98+
)
99+
@assert isblockpermuteddiagonal(A)
151100
end
152101

153102
function MatrixAlgebraKit.check_input(
154-
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
103+
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ), ::BlockDiagonalAlgorithm
155104
)
156105
@assert isa(U, AbstractBlockSparseMatrix) &&
157106
isa(S, AbstractBlockSparseMatrix) &&
@@ -160,11 +109,19 @@ function MatrixAlgebraKit.check_input(
160109
@assert real(eltype(A)) == eltype(S)
161110
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2)
162111
@assert axes(S, 1) == axes(S, 2)
112+
@assert isblockdiagonal(A)
163113
return nothing
164114
end
165115

166116
function MatrixAlgebraKit.check_input(
167-
::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
117+
::typeof(svd_full!), A::AbstractBlockSparseMatrix, USVᴴ, ::BlockPermutedDiagonalAlgorithm
118+
)
119+
@assert isblockpermuteddiagonal(A)
120+
return nothing
121+
end
122+
123+
function MatrixAlgebraKit.check_input(
124+
::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ), ::BlockDiagonalAlgorithm
168125
)
169126
@assert isa(U, AbstractBlockSparseMatrix) &&
170127
isa(S, AbstractBlockSparseMatrix) &&
@@ -173,78 +130,92 @@ function MatrixAlgebraKit.check_input(
173130
@assert real(eltype(A)) == eltype(S)
174131
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2)
175132
@assert axes(S, 2) == axes(A, 2)
133+
@assert isblockdiagonal(A)
176134
return nothing
177135
end
178136

179137
function MatrixAlgebraKit.svd_compact!(
180-
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
138+
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
181139
)
182-
check_input(svd_compact!, A, (U, S, Vᴴ))
140+
check_input(svd_compact!, A, USVᴴ, alg)
183141

184-
# do decomposition on each block
185-
for bI in eachblockstoredindex(A)
186-
brow, bcol = Tuple(bI)
187-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
188-
block = @view!(A[bI])
189-
block_alg = block_algorithm(alg, block)
190-
usvᴴ′ = svd_compact!(block, usvᴴ, block_alg)
191-
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
192-
end
142+
Ad, rowperm, colperm = blockdiagonalize(A)
143+
Ud, S, Vᴴd = svd_compact!(Ad, BlockDiagonalAlgorithm(alg))
144+
145+
inv_rowperm = Block.(invperm(Int.(rowperm)))
146+
U = Ud[inv_rowperm, :]
147+
148+
inv_colperm = Block.(invperm(Int.(colperm)))
149+
Vᴴ = Vᴴd[:, inv_colperm]
150+
151+
return U, S, Vᴴ
152+
end
193153

194-
# fill in identities for blocks that aren't present
195-
bIs = collect(eachblockstoredindex(A))
196-
browIs = Int.(first.(Tuple.(bIs)))
197-
bcolIs = Int.(last.(Tuple.(bIs)))
198-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
199-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
200-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
201-
# U[Block(row, col)] = LinearAlgebra.I
202-
# Vᴴ[Block(col, col)] = LinearAlgebra.I
203-
for (row, col) in zip(emptyrows, emptycols)
204-
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
205-
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
154+
function MatrixAlgebraKit.svd_compact!(
155+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockDiagonalAlgorithm
156+
)
157+
check_input(svd_compact!, A, (U, S, Vᴴ), alg)
158+
159+
for I in 1:min(blocksize(A)...)
160+
bI = Block(I, I)
161+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
162+
usvᴴ = (@view!(U[bI]), @view!(S[bI]), @view!(Vᴴ[bI]))
163+
block = @view!(A[bI])
164+
block_alg = block_algorithm(alg, block)
165+
usvᴴ′ = svd_compact!(block, usvᴴ, block_alg)
166+
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
167+
else
168+
copyto!(@view!(U[bI]), LinearAlgebra.I)
169+
copyto!(@view!(Vᴴ[bI]), LinearAlgebra.I)
170+
end
206171
end
207172

208-
return (U, S, Vᴴ)
173+
return U, S, Vᴴ
209174
end
210175

211176
function MatrixAlgebraKit.svd_full!(
212-
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
177+
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
213178
)
214-
check_input(svd_full!, A, (U, S, Vᴴ))
179+
check_input(svd_full!, A, USVᴴ, alg)
215180

216-
# do decomposition on each block
217-
for bI in eachblockstoredindex(A)
218-
brow, bcol = Tuple(bI)
219-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
220-
block = @view!(A[bI])
221-
block_alg = block_algorithm(alg, block)
222-
usvᴴ′ = svd_full!(block, usvᴴ, block_alg)
223-
@assert usvᴴ === usvᴴ′ "svd_full! might not be in-place"
224-
end
181+
Ad, rowperm, colperm = blockdiagonalize(A)
182+
Ud, S, Vᴴd = svd_full!(Ad, BlockDiagonalAlgorithm(alg))
183+
184+
inv_rowperm = Block.(invperm(Int.(rowperm)))
185+
U = Ud[inv_rowperm, :]
225186

226-
# fill in identities for blocks that aren't present
227-
bIs = collect(eachblockstoredindex(A))
228-
browIs = Int.(first.(Tuple.(bIs)))
229-
bcolIs = Int.(last.(Tuple.(bIs)))
230-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
231-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
232-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
233-
# U[Block(row, col)] = LinearAlgebra.I
234-
# Vt[Block(col, col)] = LinearAlgebra.I
235-
for (row, col) in zip(emptyrows, emptycols)
236-
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
237-
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
187+
inv_colperm = Block.(invperm(Int.(colperm)))
188+
Vᴴ = Vᴴd[:, inv_colperm]
189+
190+
return U, S, Vᴴ
191+
end
192+
193+
function MatrixAlgebraKit.svd_full!(
194+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockDiagonalAlgorithm
195+
)
196+
check_input(svd_full!, A, (U, S, Vᴴ), alg)
197+
198+
for I in 1:min(blocksize(A)...)
199+
bI = Block(I, I)
200+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
201+
usvᴴ = (@view!(U[bI]), @view!(S[bI]), @view!(Vᴴ[bI]))
202+
block = @view!(A[bI])
203+
block_alg = block_algorithm(alg, block)
204+
usvᴴ′ = svd_full!(block, usvᴴ, block_alg)
205+
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
206+
else
207+
copyto!(@view!(U[bI]), LinearAlgebra.I)
208+
copyto!(@view!(Vᴴ[bI]), LinearAlgebra.I)
209+
end
238210
end
239211

240-
# also handle extra rows/cols
241-
for i in (length(emptyrows) + 1):length(emptycols)
242-
copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
212+
# Complete the unitaries for rectangular inputs
213+
for I in blocksize(A, 2)+1:blocksize(A, 1)
214+
copyto!(@view!(U[Block(I, I)]), LinearAlgebra.I)
243215
end
244-
bn = blocksize(A, 2)
245-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
246-
copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
216+
for I in blocksize(A, 1)+1:blocksize(A, 2)
217+
copyto!(@view!(Vᴴ[Block(I, I)]), LinearAlgebra.I)
247218
end
248219

249-
return (U, S, Vᴴ)
220+
return U, S, Vᴴ
250221
end

0 commit comments

Comments
 (0)