Skip to content

Commit 1e60f85

Browse files
committed
Jacobi SVD algorithm
1 parent d7eba9f commit 1e60f85

File tree

3 files changed

+77
-78
lines changed

3 files changed

+77
-78
lines changed

ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgori
1515
YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
1616
elseif alg isa CUSOLVER_SVDPolar
1717
YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
18+
elseif alg isa CUSOLVER_Jacobi
19+
YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
1820
# elseif alg isa LAPACK_Bisection
1921
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
2022
# elseif alg isa LAPACK_Jacobi
@@ -54,16 +56,14 @@ function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlg
5456
YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ)
5557
elseif alg isa CUSOLVER_SVDPolar
5658
YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
59+
elseif alg isa CUSOLVER_Jacobi
60+
YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...)
5761
# elseif alg isa LAPACK_DivideAndConquer
5862
# isempty(alg.kwargs) ||
5963
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
6064
# YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
6165
# elseif alg isa LAPACK_Bisection
6266
# YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
63-
# elseif alg isa LAPACK_Jacobi
64-
# isempty(alg.kwargs) ||
65-
# throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
66-
# YALAPACK.gesvj!(A, S.diag, U, Vᴴ)
6767
else
6868
throw(ArgumentError("Unsupported SVD algorithm"))
6969
end
@@ -89,6 +89,8 @@ function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
8989
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
9090
elseif alg isa CUSOLVER_SVDPolar
9191
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
92+
elseif alg isa CUSOLVER_Jacobi
93+
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...)
9294
# elseif alg isa LAPACK_DivideAndConquer
9395
# isempty(alg.kwargs) ||
9496
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -171,84 +171,81 @@ function Xgesvdp!(A::StridedCuMatrix{T},
171171
end
172172
173173
# Wrapper for SVD via Jacobi
174-
# for (bname, fname, elty, relty) in
175-
# ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32),
176-
# (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64),
177-
# (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32),
178-
# (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64))
179-
# @eval begin
180-
# #! format: off
181-
# function gesvdj!(A::StridedCuMatrix{$elty},
182-
# S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
183-
# U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
184-
# Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
185-
# tol::$relty=eps($relty),
186-
# max_sweeps::Int=100)
187-
# #! format: on
188-
# chkstride1(A, U, Vᴴ, S)
189-
# m, n = size(A)
190-
# minmn = min(m, n)
191-
192-
# if length(U) == 0 && length(Vᴴ) == 0
193-
# jobz = 'N'
194-
# econ = 0
195-
# else
196-
# jobz = 'V'
197-
# size(U, 1) == m ||
198-
# throw(DimensionMismatch("row size mismatch between A and U"))
199-
# size(Vᴴ, 2) == n ||
200-
# throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
201-
# if size(U, 2) == size(Vᴴ, 1) == minmn
202-
# econ = 1
203-
# elseif size(U, 2) == m && size(Vᴴ, 1) == n
204-
# econ = 0
205-
# else
206-
# throw(DimensionMismatch("invalid column size of U or row size of Vᴴ"))
207-
# end
208-
# end
209-
# length(S) == minmn ||
210-
# throw(DimensionMismatch("length mismatch between A and S"))
211-
212-
# if jobz == 'N' # it seems we still need the memory for U and Vᴴ
213-
# U = similar(A, $elty, m, minmn)
214-
# V = similar(A, $elty, n, minmn)
215-
# else
216-
# V = similar(Vᴴ')
217-
# end
218-
# lda = max(1, stride(A, 2))
219-
# ldu = max(1, stride(U, 2))
220-
# ldv = max(1, stride(V, 2))
174+
for (bname, fname, elty, relty) in
175+
((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32),
176+
(:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64),
177+
(:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32),
178+
(:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64))
179+
@eval begin
180+
#! format: off
181+
function gesvdj!(A::StridedCuMatrix{$elty},
182+
S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
183+
U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
184+
Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
185+
tol::$relty=eps($relty),
186+
max_sweeps::Int=100)
187+
#! format: on
188+
chkstride1(A, U, Vᴴ, S)
189+
m, n = size(A)
190+
minmn = min(m, n)
221191
222-
# params = Ref{gesvdjInfo_t}(C_NULL)
223-
# cusolverDnCreateGesvdjInfo(params)
224-
# cusolverDnXgesvdjSetTolerance(params[], tol)
225-
# cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)
226-
# dh = dense_handle()
192+
if length(U) == 0 && length(Vᴴ) == 0
193+
jobz = 'N'
194+
econ = 0
195+
else
196+
jobz = 'V'
197+
size(U, 1) == m ||
198+
throw(DimensionMismatch("row size mismatch between A and U"))
199+
size(Vᴴ, 2) == n ||
200+
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
201+
if size(U, 2) == size(Vᴴ, 1) == minmn
202+
econ = 1
203+
elseif size(U, 2) == m && size(Vᴴ, 1) == n
204+
econ = 0
205+
else
206+
throw(DimensionMismatch("invalid column size of U or row size of Vᴴ"))
207+
end
208+
end
209+
length(S) == minmn ||
210+
throw(DimensionMismatch("length mismatch between A and S"))
227211
228-
# function bufferSize()
229-
# out = Ref{Cint}(0)
230-
# $bname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
231-
# out, params[])
232-
# return out[] * sizeof($elty)
233-
# end
212+
Ṽ = (jobz == 'V') ? similar(Vᴴ') : similar(Vᴴ, (n, minmn))
213+
Ũ = (jobz == 'V') ? U : similar(U, (m, minmn))
214+
lda = max(1, stride(A, 2))
215+
ldu = max(1, stride(Ũ, 2))
216+
ldv = max(1, stride(Ṽ, 2))
234217

235-
# with_workspace(dh.workspace_gpu, bufferSize) do buffer
236-
# return $fname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
237-
# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
238-
# end
218+
params = Ref{CUSOLVER.gesvdjInfo_t}(C_NULL)
219+
CUSOLVER.cusolverDnCreateGesvdjInfo(params)
220+
CUSOLVER.cusolverDnXgesvdjSetTolerance(params[], tol)
221+
CUSOLVER.cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)
222+
dh = CUSOLVER.dense_handle()
239223

240-
# info = @allowscalar dh.info[1]
241-
# chkargsok(BlasInt(info))
224+
function bufferSize()
225+
out = Ref{Cint}(0)
226+
CUSOLVER.$bname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
227+
out, params[])
228+
return out[] * sizeof($elty)
229+
end
242230

243-
# cusolverDnDestroyGesvdjInfo(params[])
231+
CUSOLVER.with_workspace(dh.workspace_gpu, bufferSize) do buffer
232+
return CUSOLVER.$fname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
233+
buffer, sizeof(buffer) ÷ sizeof($elty), dh.info,
234+
params[])
235+
end
244236

245-
# if jobz != 'N'
246-
# adjoint!(Vᴴ, V)
247-
# end
248-
# return U, S, Vᴴ
249-
# end
250-
# end
251-
# end
237+
info = @allowscalar dh.info[1]
238+
CUSOLVER.chkargsok(BlasInt(info))
239+
240+
CUSOLVER.cusolverDnDestroyGesvdjInfo(params[])
241+
242+
if jobz == 'V'
243+
adjoint!(Vᴴ, Ṽ)
244+
end
245+
return U, S, Vᴴ
246+
end
247+
end
248+
end
252249

253250
# for (jname, bname, fname, elty, relty) in
254251
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),

test/cuda/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include("utilities.jl")
1313
m = 54
1414
@testset "size ($m, $n)" for n in (37, m, 63)
1515
k = min(m, n)
16-
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar())
16+
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
1818
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
1919
minmn = min(m, n)
@@ -49,7 +49,7 @@ end
4949
rng = StableRNG(123)
5050
m = 54
5151
@testset "size ($m, $n)" for n in (37, m, 63)
52-
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar())
52+
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
5353
@testset "algorithm $alg" for alg in algs
5454
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
5555
A = CuArray(randn(rng, T, m, n))

0 commit comments

Comments
 (0)