Skip to content

Commit 5645ddc

Browse files
amontoisonmaleadt
andauthored
[CUSOLVER] Fix gesvdp! when only requesting singular values (#2763)
Co-authored-by: Tim Besard <[email protected]>
1 parent 2807156 commit 5645ddc

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -283,32 +283,27 @@ end
283283

284284
# Xgesvdp
285285
function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: BlasFloat}
286+
econ in (0, 1) || throw(ArgumentError("econ is incorrect. The values accepted are 0 and 1."))
287+
return Xgesvdp!(jobz, Bool(econ), A)
288+
end
289+
290+
function Xgesvdp!(jobz::Char, econ::Bool, A::StridedCuMatrix{T}) where {T <: BlasFloat}
286291
m, n = size(A)
287292
p = min(m, n)
288293
R = real(T)
289-
econ (0, 1) || throw(ArgumentError("econ is incorrect. The values accepted are 0 and 1."))
290-
U = if jobz == 'V' && econ == 1
291-
CuMatrix{T}(undef, m, p)
292-
elseif jobz == 'V' && econ == 0
293-
CuMatrix{T}(undef, m, m)
294-
elseif jobz == 'N'
295-
CU_NULL
294+
jobz in ('N', 'V') || throw(ArgumentError("jobz is incorrect. The values accepted are 'V' and 'N'."))
295+
296+
if econ
297+
U = CuMatrix{T}(undef, m, p)
298+
V = CuMatrix{T}(undef, n, p)
296299
else
297-
throw(ArgumentError("jobz is incorrect. The values accepted are 'V' and 'N'."))
300+
U = CuMatrix{T}(undef, m, m)
301+
V = CuMatrix{T}(undef, n, n)
298302
end
299303
Σ = CuVector{R}(undef, p)
300-
V = if jobz == 'V' && econ == 1
301-
CuMatrix{T}(undef, n, p)
302-
elseif jobz == 'V' && econ == 0
303-
CuMatrix{T}(undef, n, n)
304-
elseif jobz == 'N'
305-
CU_NULL
306-
else
307-
throw(ArgumentError("jobz is incorrect. The values accepted are 'V' and 'N'."))
308-
end
309304
lda = max(1, stride(A, 2))
310-
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
311-
ldv = V == CU_NULL ? 1 : max(1, stride(V, 2))
305+
ldu = max(1, stride(U, 2))
306+
ldv = max(1, stride(V, 2))
312307
h_err_sigma = Ref{Cdouble}(0)
313308
params = CuSolverParameters()
314309
dh = dense_handle()
@@ -331,6 +326,8 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
331326
flag = @allowscalar dh.info[1]
332327
chklapackerror(flag |> BlasInt)
333328
if jobz == 'N'
329+
unsafe_free!(U)
330+
unsafe_free!(V)
334331
return Σ, h_err_sigma[]
335332
elseif jobz == 'V'
336333
return U, Σ, V, h_err_sigma[]

test/libraries/cusolver/dense_generic.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ p = 5
206206
end
207207

208208
@testset "gesvdp!" begin
209+
# nrows > ncols
209210
A = rand(elty,m,n)
210211
d_A = CuMatrix(A)
211212
U, Σ, V, err_sigma = CUSOLVER.Xgesvdp!('V', 0, d_A)
@@ -215,6 +216,15 @@ p = 5
215216
U, Σ, V, err_sigma = CUSOLVER.Xgesvdp!('V', 1, d_A)
216217
@test A collect(U) * Diagonal(collect(Σ)) * collect(V)'
217218

219+
d_A = CuMatrix(A)
220+
Σ2, err_sigma = CUSOLVER.Xgesvdp!('N', 0, d_A)
221+
@test collect(Σ) collect(Σ2)
222+
223+
d_A = CuMatrix(A)
224+
Σ3, err_sigma = CUSOLVER.Xgesvdp!('N', 1, d_A)
225+
@test collect(Σ) collect(Σ3)
226+
227+
# nrows < ncols
218228
A = rand(elty,n,m)
219229
d_A = CuMatrix(A)
220230
U, Σ, V, err_sigma = CUSOLVER.Xgesvdp!('V', 0, d_A)
@@ -223,6 +233,14 @@ p = 5
223233
d_A = CuMatrix(A)
224234
U, Σ, V, err_sigma = CUSOLVER.Xgesvdp!('V', 1, d_A)
225235
@test A collect(U) * Diagonal(collect(Σ)) * collect(V)'
236+
237+
d_A = CuMatrix(A)
238+
Σ2, err_sigma = CUSOLVER.Xgesvdp!('N', 0, d_A)
239+
@test collect(Σ) collect(Σ2)
240+
241+
d_A = CuMatrix(A)
242+
Σ3, err_sigma = CUSOLVER.Xgesvdp!('N', 1, d_A)
243+
@test collect(Σ) collect(Σ3)
226244
end
227245

228246
@testset "gesvdr!" begin

0 commit comments

Comments
 (0)