1- using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full!
1+ using MatrixAlgebraKit:
2+ MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
23
34"""
45 BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output(
152153end
153154
154155function MatrixAlgebraKit. check_input(
155- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , USVᴴ
156+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
156157)
157- U, S, Vt = USVᴴ
158158 @assert isa(U, AbstractBlockSparseMatrix) &&
159159 isa(S, AbstractBlockSparseMatrix) &&
160- isa(Vt , AbstractBlockSparseMatrix)
161- @assert eltype(A) == eltype(U) == eltype(Vt )
160+ isa(Vᴴ , AbstractBlockSparseMatrix)
161+ @assert eltype(A) == eltype(U) == eltype(Vᴴ )
162162 @assert real(eltype(A)) == eltype(S)
163- @assert axes(A, 1 ) == axes(U, 1 ) && axes(A, 2 ) == axes(Vt , 2 )
163+ @assert axes(A, 1 ) == axes(U, 1 ) && axes(A, 2 ) == axes(Vᴴ , 2 )
164164 @assert axes(S, 1 ) == axes(S, 2 )
165-
166165 return nothing
167166end
168167
169168function MatrixAlgebraKit. check_input(
170- :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , USVᴴ
169+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
171170)
172- U, S, Vt = USVᴴ
173171 @assert isa(U, AbstractBlockSparseMatrix) &&
174172 isa(S, AbstractBlockSparseMatrix) &&
175- isa(Vt , AbstractBlockSparseMatrix)
176- @assert eltype(A) == eltype(U) == eltype(Vt )
173+ isa(Vᴴ , AbstractBlockSparseMatrix)
174+ @assert eltype(A) == eltype(U) == eltype(Vᴴ )
177175 @assert real(eltype(A)) == eltype(S)
178- @assert axes(A, 1 ) == axes(U, 1 ) && axes(A, 2 ) == axes(Vt , 1 ) == axes(Vt , 2 )
176+ @assert axes(A, 1 ) == axes(U, 1 ) && axes(A, 2 ) == axes(Vᴴ , 1 ) == axes(Vᴴ , 2 )
179177 @assert axes(S, 2 ) == axes(A, 2 )
180-
181178 return nothing
182179end
183180
184181function MatrixAlgebraKit. svd_compact!(
185- A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
182+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
186183)
187- MatrixAlgebraKit. check_input(svd_compact!, A, USVᴴ)
188- U, S, Vt = USVᴴ
184+ check_input(svd_compact!, A, (U, S, Vᴴ))
189185
190186 # do decomposition on each block
191187 for bI in eachblockstoredindex(A)
192188 brow, bcol = Tuple(bI)
193- usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt [bcol, bcol]))
189+ usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ [bcol, bcol]))
194190 usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg. alg)
195191 @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
196192 end
@@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!(
203199 emptycols = setdiff(1 : blocksize(A, 2 ), bcolIs)
204200 # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
205201 # U[Block(row, col)] = LinearAlgebra.I
206- # Vt [Block(col, col)] = LinearAlgebra.I
202+ # Vᴴ [Block(col, col)] = LinearAlgebra.I
207203 for (row, col) in zip(emptyrows, emptycols)
208204 copyto!(@view!(U[Block(row, col)]), LinearAlgebra. I)
209- copyto!(@view!(Vt [Block(col, col)]), LinearAlgebra. I)
205+ copyto!(@view!(Vᴴ [Block(col, col)]), LinearAlgebra. I)
210206 end
211207
212- return USVᴴ
208+ return (U, S, Vᴴ)
213209end
214210
215211function MatrixAlgebraKit. svd_full!(
216- A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
212+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
217213)
218- MatrixAlgebraKit. check_input(svd_full!, A, USVᴴ)
219- U, S, Vt = USVᴴ
214+ check_input(svd_full!, A, (U, S, Vᴴ))
220215
221216 # do decomposition on each block
222217 for bI in eachblockstoredindex(A)
223218 brow, bcol = Tuple(bI)
224- usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt [bcol, bcol]))
219+ usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ [bcol, bcol]))
225220 usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg. alg)
226221 @assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
227222 end
@@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!(
237232 # Vt[Block(col, col)] = LinearAlgebra.I
238233 for (row, col) in zip(emptyrows, emptycols)
239234 copyto!(@view!(U[Block(row, col)]), LinearAlgebra. I)
240- copyto!(@view!(Vt [Block(col, col)]), LinearAlgebra. I)
235+ copyto!(@view!(Vᴴ [Block(col, col)]), LinearAlgebra. I)
241236 end
242237
243238 # also handle extra rows/cols
244239 for i in (length(emptyrows) + 1 ): length(emptycols)
245- copyto!(@view!(Vt [Block(emptycols[i], emptycols[i])]), LinearAlgebra. I)
240+ copyto!(@view!(Vᴴ [Block(emptycols[i], emptycols[i])]), LinearAlgebra. I)
246241 end
247242 bn = blocksize(A, 2 )
248243 for (i, k) in enumerate((length(emptycols) + 1 ): length(emptyrows))
249244 copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra. I)
250245 end
251246
252- return USVᴴ
247+ return (U, S, Vᴴ)
253248end
0 commit comments