@@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in
242242 if jobz == ' V'
243243 adjoint!(Vᴴ, Ṽ)
244244 end
245- return U, S , Vᴴ
245+ return S, U , Vᴴ
246246 end
247247 end
248248end
249249
250+ # Wrapper for randomized SVD
251+ function Xgesvdr!(A:: StridedCuMatrix{T} ,
252+ S:: StridedCuVector = similar(A, real(T), min(size(A). .. )),
253+ U:: StridedCuMatrix{T} = similar(A, T, size(A, 1 ), min(size(A). .. )),
254+ Vᴴ:: StridedCuMatrix{T} = similar(A, T, min(size(A). .. ), size(A, 2 ));
255+ k:: Int = length(S),
256+ p:: Int = min(size(A). .. )- k- 1 ,
257+ niters:: Int = 1 ) where {T<: BlasFloat }
258+ chkstride1(A, U, S, Vᴴ)
259+ m, n = size(A)
260+ minmn = min(m, n)
261+ jobu = length(U) == 0 ? ' N' : ' S'
262+ jobv = length(Vᴴ) == 0 ? ' N' : ' S'
263+ R = eltype(S)
264+ k < minmn || throw(DimensionMismatch(" length of S ($k ) must be less than the smaller dimension of A ($minmn )" ))
265+ k + p < minmn || throw(DimensionMismatch(" length of S ($k ) plus oversampling ($p ) must be less than the smaller dimension of A ($minmn )" ))
266+ R == real(T) ||
267+ throw(ArgumentError(" S does not have the matching real `eltype` of A" ))
268+
269+ Ṽ = similar(Vᴴ, (n, n))
270+ Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
271+ lda = max(1 , stride(A, 2 ))
272+ ldu = max(1 , stride(Ũ, 2 ))
273+ ldv = max(1 , stride(Ṽ, 2 ))
274+ params = CUSOLVER. CuSolverParameters()
275+ dh = CUSOLVER. dense_handle()
276+
277+ function bufferSize()
278+ out_cpu = Ref{Csize_t}(0 )
279+ out_gpu = Ref{Csize_t}(0 )
280+ CUSOLVER. cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
281+ T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
282+ T, out_gpu, out_cpu)
283+
284+ return out_gpu[], out_cpu[]
285+ end
286+ CUSOLVER. with_workspaces(dh. workspace_gpu, dh. workspace_cpu,
287+ bufferSize(). .. ) do buffer_gpu, buffer_cpu
288+ return CUSOLVER. cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
289+ T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
290+ T, buffer_gpu, sizeof(buffer_gpu),
291+ buffer_cpu, sizeof(buffer_cpu),
292+ dh. info)
293+ end
294+
295+ flag = @allowscalar dh. info[1 ]
296+ CUSOLVER. chklapackerror(BlasInt(flag))
297+ if Ũ != = U && length(U) > 0
298+ U .= view(Ũ, 1 : m, 1 : size(U, 2 ))
299+ end
300+ if length(Vᴴ) > 0
301+ Vᴴ .= view(Ṽ' , 1:size(Vᴴ, 1), 1:n)
302+ end
303+ Ũ !== U && CUDA.unsafe_free!(Ũ)
304+ CUDA.unsafe_free!(Ṽ)
305+
306+ return S, U, Vᴴ
307+ end
308+
250309# for (jname, bname, fname, elty, relty) in
251310# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
252311# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
0 commit comments