Skip to content

Commit 7a73cc0

Browse files
committed
also update svd_full
1 parent 2b2deac commit 7a73cc0

File tree

2 files changed

+37
-41
lines changed

2 files changed

+37
-41
lines changed

src/factorizations/svd.jl

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -112,55 +112,51 @@ function MatrixAlgebraKit.initialize_output(
112112
bm, bn = blocksize(A)
113113

114114
brows = eachblockaxis(axes(A, 1))
115-
u_axes = similar(brows)
115+
bcols = eachblockaxis(axes(A, 2))
116+
u_axes = similar(brows, bm)
117+
v_axes = similar(bcols, bn)
116118

117119
# fill in values for blocks that are present
118-
bIs = collect(eachblockstoredindex(A))
120+
bIs = sort!(collect(eachblockstoredindex(A)), by=Int last Tuple)
119121
browIs = Int.(first.(Tuple.(bIs)))
120122
bcolIs = Int.(last.(Tuple.(bIs)))
121-
for bI in eachblockstoredindex(A)
123+
for (I, bI) in enumerate(bIs)
122124
row, col = Int.(Tuple(bI))
123-
u_axes[col] = brows[row]
125+
u_axes[I] = brows[row]
126+
v_axes[I] = bcols[col]
124127
end
125128

126129
# fill in values for blocks that aren't present, pairing them in order of occurence
127130
# this is a convention, which at least gives the expected results for blockdiagonal
128131
emptyrows = setdiff(1:bm, browIs)
132+
u_axes[length(bIs) .+ (1:length(emptyrows))] .= brows[emptyrows]
129133
emptycols = setdiff(1:bn, bcolIs)
130-
for (row, col) in zip(emptyrows, emptycols)
131-
u_axes[col] = brows[row]
132-
end
133-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
134-
u_axes[bn + i] = brows[emptyrows[k]]
135-
end
136-
134+
v_axes[length(bIs) .+ (1:length(emptycols))] .= bcols[emptycols]
135+
137136
u_axis = mortar_axis(u_axes)
138-
S_axes = (u_axis, axes(A, 2))
137+
v_axis = mortar_axis(@show v_axes)
138+
S_axes = (u_axis, v_axis)
139139
U, S, Vt = similar_output(svd_full!, A, S_axes, alg)
140140

141141
# allocate output
142-
for bI in eachblockstoredindex(A)
142+
for (I, bI) in enumerate(bIs)
143143
brow, bcol = Tuple(bI)
144+
bcol′ = Block(I)
144145
block = @view!(A[bI])
145146
block_alg = block_algorithm(alg, block)
146-
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
147+
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
147148
svd_full!, block, block_alg
148149
)
149150
end
150151

151152
# allocate output for blocks that aren't present -- do we also fill identities here?
152-
for (row, col) in zip(emptyrows, emptycols)
153-
@view!(U[Block(row, col)])
154-
@view!(Vt[Block(col, col)])
153+
for (I, row) in enumerate(emptyrows)
154+
@view!(U[Block(row, I)])
155155
end
156-
# also handle extra rows/cols
157-
for i in (length(emptyrows) + 1):length(emptycols)
158-
@view!(Vt[Block(emptycols[i], emptycols[i])])
156+
for (I, col) in enumerate(emptycols)
157+
@view!(Vt[Block(I, col)])
159158
end
160-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
161-
@view!(U[Block(emptyrows[k], bn + i)])
162-
end
163-
159+
164160
return U, S, Vt
165161
end
166162

@@ -185,8 +181,7 @@ function MatrixAlgebraKit.check_input(
185181
isa(Vᴴ, AbstractBlockSparseMatrix)
186182
@assert eltype(A) == eltype(U) == eltype(Vᴴ)
187183
@assert real(eltype(A)) == eltype(S)
188-
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2)
189-
@assert axes(S, 2) == axes(A, 2)
184+
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2)
190185
return nothing
191186
end
192187

@@ -208,7 +203,6 @@ function MatrixAlgebraKit.svd_compact!(
208203
end
209204

210205
# fill in identities for blocks that aren't present
211-
bIs = collect(eachblockstoredindex(A))
212206
browIs = Int.(first.(Tuple.(bIs)))
213207
bcolIs = Int.(last.(Tuple.(bIs)))
214208
emptyrows = setdiff(1:blocksize(A, 1), browIs)
@@ -230,36 +224,30 @@ function MatrixAlgebraKit.svd_full!(
230224
check_input(svd_full!, A, (U, S, Vᴴ))
231225

232226
# do decomposition on each block
233-
for bI in eachblockstoredindex(A)
227+
bIs = sort!(collect(eachblockstoredindex(A)); by=Int last Tuple)
228+
for (I, bI) in enumerate(bIs)
234229
brow, bcol = Tuple(bI)
235-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
230+
bcol′ = Block(I)
231+
usvᴴ = (@view!(U[brow, bcol′]), @view!(S[bcol′, bcol′]), @view!(Vᴴ[bcol′, bcol]))
236232
block = @view!(A[bI])
237233
block_alg = block_algorithm(alg, block)
238234
usvᴴ′ = svd_full!(block, usvᴴ, block_alg)
239235
@assert usvᴴ === usvᴴ′ "svd_full! might not be in-place"
240236
end
241237

242238
# fill in identities for blocks that aren't present
243-
bIs = collect(eachblockstoredindex(A))
244239
browIs = Int.(first.(Tuple.(bIs)))
245240
bcolIs = Int.(last.(Tuple.(bIs)))
246241
emptyrows = setdiff(1:blocksize(A, 1), browIs)
247242
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
248243
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
249244
# U[Block(row, col)] = LinearAlgebra.I
250245
# Vt[Block(col, col)] = LinearAlgebra.I
251-
for (row, col) in zip(emptyrows, emptycols)
252-
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
253-
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
254-
end
255-
256-
# also handle extra rows/cols
257-
for i in (length(emptyrows) + 1):length(emptycols)
258-
copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
246+
for (I, row) in enumerate(emptyrows)
247+
copyto!(@view!(U[Block(row, length(bIs) + I)]), LinearAlgebra.I)
259248
end
260-
bn = blocksize(A, 2)
261-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
262-
copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
249+
for (I, col) in enumerate(emptycols)
250+
copyto!(@view!(Vᴴ[Block(length(bIs) + I, col)]), LinearAlgebra.I)
263251
end
264252

265253
return (U, S, Vᴴ)

test/test_issues.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,12 @@ using Test: @test, @testset
1414
@test U * S * Vᴴ a
1515
@test U' * U LinearAlgebra.I
1616
@test Vᴴ * Vᴴ' LinearAlgebra.I
17+
18+
U, S, Vᴴ = svd_full(a);
19+
20+
@test U * S * Vᴴ a
21+
@test U' * U LinearAlgebra.I
22+
@test U * U' LinearAlgebra.I
23+
@test Vᴴ * Vᴴ' LinearAlgebra.I
24+
@test Vᴴ' * Vᴴ LinearAlgebra.I
1725
end

0 commit comments

Comments
 (0)