@@ -283,32 +283,27 @@ end
283
283
284
284
# Xgesvdp
285
285
function Xgesvdp! (jobz:: Char , econ:: Int , A:: StridedCuMatrix{T} ) where {T <: BlasFloat }
286
+ econ in (0 , 1 ) || throw (ArgumentError (" econ is incorrect. The values accepted are 0 and 1." ))
287
+ return Xgesvdp! (jobz, Bool (econ), A)
288
+ end
289
+
290
+ function Xgesvdp! (jobz:: Char , econ:: Bool , A:: StridedCuMatrix{T} ) where {T <: BlasFloat }
286
291
m, n = size (A)
287
292
p = min (m, n)
288
293
R = real (T)
289
- econ ∈ (0 , 1 ) || throw (ArgumentError (" econ is incorrect. The values accepted are 0 and 1." ))
290
- U = if jobz == ' V' && econ == 1
291
- CuMatrix {T} (undef, m, p)
292
- elseif jobz == ' V' && econ == 0
293
- CuMatrix {T} (undef, m, m)
294
- elseif jobz == ' N'
295
- CU_NULL
294
+ jobz in (' N' , ' V' ) || throw (ArgumentError (" jobz is incorrect. The values accepted are 'V' and 'N'." ))
295
+
296
+ if econ
297
+ U = CuMatrix {T} (undef, m, p)
298
+ V = CuMatrix {T} (undef, n, p)
296
299
else
297
- throw (ArgumentError (" jobz is incorrect. The values accepted are 'V' and 'N'." ))
300
+ U = CuMatrix {T} (undef, m, m)
301
+ V = CuMatrix {T} (undef, n, n)
298
302
end
299
303
Σ = CuVector {R} (undef, p)
300
- V = if jobz == ' V' && econ == 1
301
- CuMatrix {T} (undef, n, p)
302
- elseif jobz == ' V' && econ == 0
303
- CuMatrix {T} (undef, n, n)
304
- elseif jobz == ' N'
305
- CU_NULL
306
- else
307
- throw (ArgumentError (" jobz is incorrect. The values accepted are 'V' and 'N'." ))
308
- end
309
304
lda = max (1 , stride (A, 2 ))
310
- ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
311
- ldv = V == CU_NULL ? 1 : max (1 , stride (V, 2 ))
305
+ ldu = max (1 , stride (U, 2 ))
306
+ ldv = max (1 , stride (V, 2 ))
312
307
h_err_sigma = Ref {Cdouble} (0 )
313
308
params = CuSolverParameters ()
314
309
dh = dense_handle ()
@@ -331,6 +326,8 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
331
326
flag = @allowscalar dh. info[1 ]
332
327
chklapackerror (flag |> BlasInt)
333
328
if jobz == ' N'
329
+ unsafe_free! (U)
330
+ unsafe_free! (V)
334
331
return Σ, h_err_sigma[]
335
332
elseif jobz == ' V'
336
333
return U, Σ, V, h_err_sigma[]
0 commit comments