Skip to content

Commit 4df87f4

Browse files
committed
Check inputs
1 parent 5146831 commit 4df87f4

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

src/factorizations/eig.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using LinearAlgebra: LinearAlgebra, Diagonal
44
using MatrixAlgebraKit:
55
MatrixAlgebraKit,
66
TruncationStrategy,
7+
check_input,
78
default_eig_algorithm,
89
default_eigh_algorithm,
910
diagview,
@@ -24,6 +25,26 @@ for f in [:default_eig_algorithm, :default_eigh_algorithm]
2425
end
2526
end
2627

28+
function MatrixAlgebraKit.check_input(
29+
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V)
30+
)
31+
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
32+
@assert eltype(V) === eltype(D) === complex(eltype(A))
33+
@assert axes(A, 1) == axes(A, 2)
34+
@assert axes(A) == axes(D) == axes(V)
35+
return nothing
36+
end
37+
function MatrixAlgebraKit.check_input(
38+
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V)
39+
)
40+
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
41+
@assert eltype(V) === eltype(A)
42+
@assert eltype(D) === real(eltype(A))
43+
@assert axes(A, 1) == axes(A, 2)
44+
@assert axes(A) == axes(D) == axes(V)
45+
return nothing
46+
end
47+
2748
for f in [:eig_full!, :eigh_full!]
2849
@eval begin
2950
function MatrixAlgebraKit.initialize_output(
@@ -37,6 +58,7 @@ for f in [:eig_full!, :eigh_full!]
3758
function MatrixAlgebraKit.$f(
3859
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
3960
)
61+
check_input($f, A, (D, V))
4062
for I in eachstoredblockdiagindex(A)
4163
D[I], V[I] = $f(@view(A[I]), alg.alg)
4264
end

src/factorizations/svd.jl

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full!
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
23

34
"""
45
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output(
152153
end
153154

154155
function MatrixAlgebraKit.check_input(
155-
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, USVᴴ
156+
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
156157
)
157-
U, S, Vt = USVᴴ
158158
@assert isa(U, AbstractBlockSparseMatrix) &&
159159
isa(S, AbstractBlockSparseMatrix) &&
160-
isa(Vt, AbstractBlockSparseMatrix)
161-
@assert eltype(A) == eltype(U) == eltype(Vt)
160+
isa(Vᴴ, AbstractBlockSparseMatrix)
161+
@assert eltype(A) == eltype(U) == eltype(Vᴴ)
162162
@assert real(eltype(A)) == eltype(S)
163-
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 2)
163+
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2)
164164
@assert axes(S, 1) == axes(S, 2)
165-
166165
return nothing
167166
end
168167

169168
function MatrixAlgebraKit.check_input(
170-
::typeof(svd_full!), A::AbstractBlockSparseMatrix, USVᴴ
169+
::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
171170
)
172-
U, S, Vt = USVᴴ
173171
@assert isa(U, AbstractBlockSparseMatrix) &&
174172
isa(S, AbstractBlockSparseMatrix) &&
175-
isa(Vt, AbstractBlockSparseMatrix)
176-
@assert eltype(A) == eltype(U) == eltype(Vt)
173+
isa(Vᴴ, AbstractBlockSparseMatrix)
174+
@assert eltype(A) == eltype(U) == eltype(Vᴴ)
177175
@assert real(eltype(A)) == eltype(S)
178-
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 1) == axes(Vt, 2)
176+
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2)
179177
@assert axes(S, 2) == axes(A, 2)
180-
181178
return nothing
182179
end
183180

184181
function MatrixAlgebraKit.svd_compact!(
185-
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
182+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
186183
)
187-
MatrixAlgebraKit.check_input(svd_compact!, A, USVᴴ)
188-
U, S, Vt = USVᴴ
184+
check_input(svd_compact!, A, (U, S, Vᴴ))
189185

190186
# do decomposition on each block
191187
for bI in eachblockstoredindex(A)
192188
brow, bcol = Tuple(bI)
193-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol]))
189+
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
194190
usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg)
195191
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
196192
end
@@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!(
203199
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
204200
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
205201
# U[Block(row, col)] = LinearAlgebra.I
206-
# Vt[Block(col, col)] = LinearAlgebra.I
202+
# Vᴴ[Block(col, col)] = LinearAlgebra.I
207203
for (row, col) in zip(emptyrows, emptycols)
208204
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
209-
copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I)
205+
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
210206
end
211207

212-
return USVᴴ
208+
return (U, S, Vᴴ)
213209
end
214210

215211
function MatrixAlgebraKit.svd_full!(
216-
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
212+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
217213
)
218-
MatrixAlgebraKit.check_input(svd_full!, A, USVᴴ)
219-
U, S, Vt = USVᴴ
214+
check_input(svd_full!, A, (U, S, Vᴴ))
220215

221216
# do decomposition on each block
222217
for bI in eachblockstoredindex(A)
223218
brow, bcol = Tuple(bI)
224-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol]))
219+
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
225220
usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg)
226221
@assert usvᴴ === usvᴴ′ "svd_full! might not be in-place"
227222
end
@@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!(
237232
# Vt[Block(col, col)] = LinearAlgebra.I
238233
for (row, col) in zip(emptyrows, emptycols)
239234
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
240-
copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I)
235+
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
241236
end
242237

243238
# also handle extra rows/cols
244239
for i in (length(emptyrows) + 1):length(emptycols)
245-
copyto!(@view!(Vt[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
240+
copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
246241
end
247242
bn = blocksize(A, 2)
248243
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
249244
copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
250245
end
251246

252-
return USVᴴ
247+
return (U, S, Vᴴ)
253248
end

0 commit comments

Comments
 (0)