Skip to content

Commit eca8520

Browse files
committed
Attempting to wrap randomized SVD
1 parent 4d091cb commit eca8520

File tree

6 files changed

+127
-27
lines changed

6 files changed

+127
-27
lines changed

ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
22
CUSOLVER_SVDPolar,
3-
CUSOLVER_Jacobi}
3+
CUSOLVER_Jacobi,
4+
CUSOLVER_Randomized}
45

56
# CUSOLVER SVD implementation
67
function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
@@ -9,6 +10,7 @@ function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgori
910
fill!(S, zero(eltype(S)))
1011
m, n = size(A)
1112
minmn = min(m, n)
13+
k = alg isa CUSOLVER_Randomized ? get(alg.kwargs, :k, min(size(S)...)) : minmn
1214
if alg isa CUSOLVER_QRIteration
1315
isempty(alg.kwargs) ||
1416
throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments"))
@@ -17,18 +19,18 @@ function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgori
1719
YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
1820
elseif alg isa CUSOLVER_Jacobi
1921
YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
20-
# elseif alg isa LAPACK_Bisection
21-
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22-
# elseif alg isa LAPACK_Jacobi
23-
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22+
elseif alg isa CUSOLVER_Randomized
23+
YACUSOLVER.Xgesvdr!(A, view(S, 1:k, 1), U, Vᴴ; alg.kwargs...)
24+
# elseif alg isa LAPACK_Bisection
25+
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
2426
else
2527
throw(ArgumentError("Unsupported SVD algorithm"))
2628
end
2729
diagview(S) .= view(S, 1:minmn, 1)
2830
view(S, 2:minmn, 1) .= zero(eltype(S))
2931
# TODO: make this controllable using a `gaugefix` keyword argument
3032
for j in 1:max(m, n)
31-
if j <= minmn
33+
if j <= minmn
3234
u = view(U, :, j)
3335
v = view(Vᴴ, j, :)
3436
s = conj(sign(_argmaxabs(u)))
@@ -58,12 +60,6 @@ function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlg
5860
YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
5961
elseif alg isa CUSOLVER_Jacobi
6062
YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...)
61-
# elseif alg isa LAPACK_DivideAndConquer
62-
# isempty(alg.kwargs) ||
63-
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
64-
# YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
65-
# elseif alg isa LAPACK_Bisection
66-
# YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
6763
else
6864
throw(ArgumentError("Unsupported SVD algorithm"))
6965
end
@@ -81,7 +77,8 @@ _argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))
8177
_largest(x, y) = abs(x) < abs(y) ? y : x
8278

8379
function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
84-
check_input(svd_vals!, A, S)
80+
# TODO restore this also for randomized
81+
alg isa CUSOLVER_Randomized || check_input(svd_vals!, A, S)
8582
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
8683
if alg isa CUSOLVER_QRIteration
8784
isempty(alg.kwargs) ||
@@ -91,18 +88,10 @@ function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
9188
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
9289
elseif alg isa CUSOLVER_Jacobi
9390
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...)
94-
# elseif alg isa LAPACK_DivideAndConquer
95-
# isempty(alg.kwargs) ||
96-
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
97-
# YALAPACK.gesdd!(A, S, U, Vᴴ)
98-
# elseif alg isa LAPACK_Bisection
99-
# YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
100-
# elseif alg isa LAPACK_Jacobi
101-
# isempty(alg.kwargs) ||
102-
# throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
103-
# YALAPACK.gesvj!(A, S, U, Vᴴ)
91+
elseif alg isa CUSOLVER_Randomized
92+
YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; alg.kwargs...)
10493
else
10594
throw(ArgumentError("Unsupported SVD algorithm"))
10695
end
10796
return S
108-
end
97+
end

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,66 @@ for (bname, fname, elty, relty) in
247247
end
248248
end
249249

250+
# Wrapper for randomized SVD
251+
function Xgesvdr!(A::StridedCuMatrix{T},
252+
S::StridedCuVector=similar(A, real(T), min(size(A)...)),
253+
U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)),
254+
Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2));
255+
k::Int=length(S),
256+
p::Int=min(size(A)...)-k-1,
257+
niters::Int=1) where {T<:BlasFloat}
258+
chkstride1(A, U, S, Vᴴ)
259+
m, n = size(A)
260+
minmn = min(m, n)
261+
jobu = length(U) == 0 ? 'N' : 'S'
262+
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
263+
k = min(size(S)...)
264+
R = eltype(S)
265+
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
266+
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
267+
R == real(T) ||
268+
throw(ArgumentError("S does not have the matching real `eltype` of A"))
269+
270+
Ṽ = similar(Vᴴ, (n, n))
271+
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
272+
lda = max(1, stride(A, 2))
273+
ldu = max(1, stride(Ũ, 2))
274+
ldv = max(1, stride(Ṽ, 2))
275+
params = CUSOLVER.CuSolverParameters()
276+
dh = CUSOLVER.dense_handle()
277+
278+
function bufferSize()
279+
out_cpu = Ref{Csize_t}(0)
280+
out_gpu = Ref{Csize_t}(0)
281+
CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
282+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
283+
T, out_gpu, out_cpu)
284+
285+
return out_gpu[], out_cpu[]
286+
end
287+
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
288+
bufferSize()...) do buffer_gpu, buffer_cpu
289+
return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
290+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
291+
T, buffer_gpu, sizeof(buffer_gpu),
292+
buffer_cpu, sizeof(buffer_cpu),
293+
dh.info)
294+
end
295+
296+
flag = @allowscalar dh.info[1]
297+
CUSOLVER.chklapackerror(BlasInt(flag))
298+
if Ũ !== U && length(U) > 0
299+
U .= view(Ũ, 1:m, 1:size(U, 2))
300+
end
301+
if length(Vᴴ) > 0
302+
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
303+
end
304+
Ũ !== U && CUDA.unsafe_free!(Ũ)
305+
CUDA.unsafe_free!(Ṽ)
306+
307+
return S, U, Vᴴ
308+
end
309+
250310
# for (jname, bname, fname, elty, relty) in
251311
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
252312
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
@@ -591,4 +651,4 @@ end
591651
# end
592652
# end
593653

594-
end
654+
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3131
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3232
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3333
LQViaTransposedQR,
34-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi
34+
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized
3535
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3636

3737
VERSION >= v"1.11.0-DEV.469" &&

src/implementations/svd.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
6666
return initialize_output(svd_compact!, A, alg.alg)
6767
end
6868

69+
6970
# Implementation
7071
# --------------
7172
function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)

src/interface/decompositions.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,16 @@ Algorithm type to denote the CUSOLVER driver for computing the singular value de
150150
a general matrix using the Jacobi algorithm.
151151
"""
152152
@algdef CUSOLVER_Jacobi
153+
154+
"""
155+
CUSOLVER_Randomized(; k, p, niters)
156+
157+
Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of
158+
a general matrix using the randomized SVD algorithm.
159+
160+
!!! note
161+
Randomized SVD cannot compute all singular values of the input matrix `A`, only the first `k` where
162+
`k < min(m, n)`. The remainder are used for oversampling. See the [CUSOLVER documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgesvdr)
163+
for more information.
164+
"""
165+
@algdef CUSOLVER_Randomized

test/cuda/svd.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,49 @@ end
7676
@test isapproxone(Vᴴ' * Vᴴ)
7777
@test all(isposdef, diagview(S))
7878

79-
Sc = similar(A, real(T), min(m, n))
79+
minmn = min(m, n)
80+
Sc = similar(A, real(T), minmn)
8081
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
8182
@test Sc === Sc2
8283
@test CuArray(diagview(S)) ≈ Sc
8384
# CuArray is necessary because norm of CuArray view with non-unit step is broken
8485
end
86+
k = min(m, n) - 20
87+
p = min(m, n) - k - 1
88+
algs = (CUSOLVER_Randomized(; k=k, p=p, niters=100),)
89+
@testset "algorithm $alg" for alg in algs
90+
A = CuArray(randn(rng, T, m, n))
91+
Uref, Sref, Vᴴref = svd_full(A, CUSOLVER_SVDPolar())
92+
U, S, Vᴴ = svd_full(A; alg)
93+
@test U isa CuMatrix{T} && size(U) == (m, m)
94+
@test S isa CuMatrix{real(T)} && size(S) == (m, n)
95+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n)
96+
for col in 1:k
97+
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
98+
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
99+
end
100+
@test all(isposdef, view(diagview(S), 1:k))
101+
@test view(CuArray(diagview(S)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
102+
103+
Ac = similar(A)
104+
U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg)
105+
@test U2 === U
106+
@test S2 === S
107+
@test V2ᴴ === Vᴴ
108+
for col in 1:k
109+
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
110+
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
111+
end
112+
@test all(isposdef, view(diagview(S), 1:k))
113+
@test view(CuArray(diagview(S2)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
114+
115+
Sc = similar(A, real(T), k)
116+
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
117+
@test Sc === Sc2
118+
@test view(Sc, 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
119+
@test view(CuArray(diagview(S)), 1:k) ≈ Sc
120+
# CuArray is necessary because norm of CuArray view with non-unit step is broken
121+
end
85122
end
86123
end
87124

0 commit comments

Comments
 (0)