Skip to content

Commit 23300d1

Browse files
committed
Use testsuite for SVD tests
1 parent 3f1c86a commit 23300d1

File tree

11 files changed

+419
-751
lines changed

11 files changed

+419
-751
lines changed

src/implementations/svd.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ end
152152
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
153153
check_input(svd_compact!, A, USVᴴ, alg)
154154
U, S, Vᴴ = USVᴴ
155+
m, n = size(A)
156+
minmn = min(m, n)
157+
if minmn == 0
158+
one!(U)
159+
zero!(S)
160+
one!(Vᴴ)
161+
return USVᴴ
162+
end
155163

156164
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
157165
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -181,6 +189,12 @@ end
181189

182190
function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
183191
check_input(svd_vals!, A, S, alg)
192+
m, n = size(A)
193+
minmn = min(m, n)
194+
if minmn == 0
195+
zero!(S)
196+
return S
197+
end
184198
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
185199

186200
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -256,6 +270,12 @@ function svd_compact!(A, USVᴴ, alg::DiagonalAlgorithm)
256270
end
257271
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm)
258272
check_input(svd_vals!, A, S, alg)
273+
m, n = size(A)
274+
minmn = min(m, n)
275+
if minmn == 0
276+
zero!(S)
277+
return S
278+
end
259279
Ad = diagview(A)
260280
S .= abs.(Ad)
261281
sort!(S; rev = true)
@@ -407,6 +427,14 @@ end
407427
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
408428
check_input(svd_compact!, A, USVᴴ, alg)
409429
U, S, Vᴴ = USVᴴ
430+
m, n = size(A)
431+
minmn = min(m, n)
432+
if minmn == 0
433+
one!(U)
434+
zero!(S)
435+
one!(Vᴴ)
436+
return USVᴴ
437+
end
410438

411439
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
412440
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -431,6 +459,12 @@ _largest(x, y) = abs(x) < abs(y) ? y : x
431459

432460
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
433461
check_input(svd_vals!, A, S, alg)
462+
m, n = size(A)
463+
minmn = min(m, n)
464+
if minmn == 0
465+
zero!(S)
466+
return S
467+
end
434468
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
435469

436470
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})

test/amd/svd.jl

Lines changed: 0 additions & 157 deletions
This file was deleted.

test/cuda/svd.jl

Lines changed: 0 additions & 170 deletions
This file was deleted.

0 commit comments

Comments
 (0)