1
- using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full!
1
+ using MatrixAlgebraKit:
2
+ MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
2
3
3
4
"""
4
5
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output(
152
153
end
153
154
154
155
function MatrixAlgebraKit. check_input (
155
- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , USVᴴ
156
+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
156
157
)
157
- U, S, Vt = USVᴴ
158
158
@assert isa (U, AbstractBlockSparseMatrix) &&
159
159
isa (S, AbstractBlockSparseMatrix) &&
160
- isa (Vt , AbstractBlockSparseMatrix)
161
- @assert eltype (A) == eltype (U) == eltype (Vt )
160
+ isa (Vᴴ , AbstractBlockSparseMatrix)
161
+ @assert eltype (A) == eltype (U) == eltype (Vᴴ )
162
162
@assert real (eltype (A)) == eltype (S)
163
- @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vt , 2 )
163
+ @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ , 2 )
164
164
@assert axes (S, 1 ) == axes (S, 2 )
165
-
166
165
return nothing
167
166
end
168
167
169
168
function MatrixAlgebraKit. check_input (
170
- :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , USVᴴ
169
+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
171
170
)
172
- U, S, Vt = USVᴴ
173
171
@assert isa (U, AbstractBlockSparseMatrix) &&
174
172
isa (S, AbstractBlockSparseMatrix) &&
175
- isa (Vt , AbstractBlockSparseMatrix)
176
- @assert eltype (A) == eltype (U) == eltype (Vt )
173
+ isa (Vᴴ , AbstractBlockSparseMatrix)
174
+ @assert eltype (A) == eltype (U) == eltype (Vᴴ )
177
175
@assert real (eltype (A)) == eltype (S)
178
- @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vt , 1 ) == axes (Vt , 2 )
176
+ @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ , 1 ) == axes (Vᴴ , 2 )
179
177
@assert axes (S, 2 ) == axes (A, 2 )
180
-
181
178
return nothing
182
179
end
183
180
184
181
function MatrixAlgebraKit. svd_compact! (
185
- A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
182
+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
186
183
)
187
- MatrixAlgebraKit. check_input (svd_compact!, A, USVᴴ)
188
- U, S, Vt = USVᴴ
184
+ check_input (svd_compact!, A, (U, S, Vᴴ))
189
185
190
186
# do decomposition on each block
191
187
for bI in eachblockstoredindex (A)
192
188
brow, bcol = Tuple (bI)
193
- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vt [bcol, bcol]))
189
+ usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ [bcol, bcol]))
194
190
usvᴴ′ = svd_compact! (@view! (A[bI]), usvᴴ, alg. alg)
195
191
@assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
196
192
end
@@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!(
203
199
emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
204
200
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
205
201
# U[Block(row, col)] = LinearAlgebra.I
206
- # Vt [Block(col, col)] = LinearAlgebra.I
202
+ # Vᴴ [Block(col, col)] = LinearAlgebra.I
207
203
for (row, col) in zip (emptyrows, emptycols)
208
204
copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
209
- copyto! (@view! (Vt [Block (col, col)]), LinearAlgebra. I)
205
+ copyto! (@view! (Vᴴ [Block (col, col)]), LinearAlgebra. I)
210
206
end
211
207
212
- return USVᴴ
208
+ return (U, S, Vᴴ)
213
209
end
214
210
215
211
function MatrixAlgebraKit. svd_full! (
216
- A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
212
+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
217
213
)
218
- MatrixAlgebraKit. check_input (svd_full!, A, USVᴴ)
219
- U, S, Vt = USVᴴ
214
+ check_input (svd_full!, A, (U, S, Vᴴ))
220
215
221
216
# do decomposition on each block
222
217
for bI in eachblockstoredindex (A)
223
218
brow, bcol = Tuple (bI)
224
- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vt [bcol, bcol]))
219
+ usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ [bcol, bcol]))
225
220
usvᴴ′ = svd_full! (@view! (A[bI]), usvᴴ, alg. alg)
226
221
@assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
227
222
end
@@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!(
237
232
# Vt[Block(col, col)] = LinearAlgebra.I
238
233
for (row, col) in zip (emptyrows, emptycols)
239
234
copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
240
- copyto! (@view! (Vt [Block (col, col)]), LinearAlgebra. I)
235
+ copyto! (@view! (Vᴴ [Block (col, col)]), LinearAlgebra. I)
241
236
end
242
237
243
238
# also handle extra rows/cols
244
239
for i in (length (emptyrows) + 1 ): length (emptycols)
245
- copyto! (@view! (Vt [Block (emptycols[i], emptycols[i])]), LinearAlgebra. I)
240
+ copyto! (@view! (Vᴴ [Block (emptycols[i], emptycols[i])]), LinearAlgebra. I)
246
241
end
247
242
bn = blocksize (A, 2 )
248
243
for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
249
244
copyto! (@view! (U[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
250
245
end
251
246
252
- return USVᴴ
247
+ return (U, S, Vᴴ)
253
248
end
0 commit comments