Skip to content

Commit 65264b8

Browse files
committed
Working svd_trunc
1 parent e8ca7f1 commit 65264b8

File tree

2 files changed

+6
-38
lines changed

2 files changed

+6
-38
lines changed

src/implementations/svd.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
286286
throw(ArgumentError("Unsupported SVD algorithm"))
287287
end
288288
# TODO: make this controllable using a `gaugefix` keyword argument
289-
for j in 1:size(U, 2)
289+
minmn = min(size(A)...)
290+
for j in 1:minmn # make this more general to account for the larger U in CUSOVLER_Randomized
290291
u = view(U, :, j)
291292
v = view(Vᴴ, j, :)
292293
s = conj(sign(_argmaxabs(u)))

test/cuda/svd.jl

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MatrixAlgebraKit
22
using MatrixAlgebraKit: diagview
3-
using LinearAlgebra: Diagonal, isposdef
3+
using LinearAlgebra: Diagonal, isposdef, opnorm
44
using Test
55
using TestExtras
66
using StableRNGs
@@ -96,7 +96,7 @@ end
9696
p = min(m, n) - k - 1
9797
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k=k, p=p, niters=100),)
9898
@testset "algorithm $alg" for alg in algs
99-
#n > m && alg isa CUSOLVER_Jacobi && continue # not supported
99+
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
100100
hA = randn(rng, T, m, n)
101101
S₀ = svd_vals(hA)
102102
A = CuArray(hA)
@@ -105,7 +105,7 @@ end
105105

106106
U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
107107
@test length(S1.diag) == r
108-
@test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
108+
@test opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
109109

110110
if !(alg isa CUSOLVER_Randomized)
111111
s = 1 + sqrt(eps(real(T)))
@@ -114,42 +114,9 @@ end
114114
U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1]))
115115
@test length(S2.diag) == r
116116
@test U1 U2
117-
@test S1 S2
117+
@test parent(S1) parent(S2)
118118
@test V1ᴴ V2ᴴ
119119
end
120-
121-
#=A = CuArray(randn(rng, T, m, n))
122-
Uref, Sref, Vᴴref = svd_full(A, CUSOLVER_SVDPolar())
123-
U, S, Vᴴ = svd_full(A; alg)
124-
@test U isa CuMatrix{T} && size(U) == (m, m)
125-
@test S isa CuMatrix{real(T)} && size(S) == (m, n)
126-
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n)
127-
for col in 1:k
128-
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
129-
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
130-
end
131-
@test all(isposdef, view(diagview(S), 1:k))
132-
@test view(CuArray(diagview(S)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
133-
134-
Ac = similar(A)
135-
U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg)
136-
@test U2 === U
137-
@test S2 === S
138-
@test V2ᴴ === Vᴴ
139-
for col in 1:k
140-
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
141-
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
142-
end
143-
@test all(isposdef, view(diagview(S), 1:k))
144-
@test view(CuArray(diagview(S2)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
145-
146-
Sc = similar(A, real(T), k)
147-
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
148-
@test Sc === Sc2
149-
@test view(Sc, 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
150-
@test view(CuArray(diagview(S)), 1:k) ≈ Sc
151-
# CuArray is necessary because norm of CuArray view with non-unit step is broken
152-
=#
153120
end
154121
end
155122
end

0 commit comments

Comments
 (0)