@@ -171,84 +171,81 @@ function Xgesvdp!(A::StridedCuMatrix{T},
171171end
172172
# Wrapper for SVD via Jacobi (cuSOLVER `gesvdj`).
#
# Generates one `gesvdj!` method per supported element type. `bname` is the cuSOLVER
# workspace-size query, `fname` the matching solver routine, `elty` the matrix element
# type and `relty` the real type used for the singular values and the tolerance.
for (bname, fname, elty, relty) in
    ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32),
     (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64),
     (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32),
     (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64))
    @eval begin
        # gesvdj!(A, [S, U, Vᴴ]; tol, max_sweeps) -> (U, S, Vᴴ)
        #
        # Arguments
        #   A          : m×n input matrix, passed to cuSOLVER as the matrix to decompose.
        #   S          : output vector of length min(m, n).
        #   U, Vᴴ      : output factors. Passing both with zero length requests
        #                singular values only; otherwise they must be either
        #                m×min(m,n) / min(m,n)×n (economy) or m×m / n×n (full).
        # Keywords
        #   tol        : tolerance forwarded to `cusolverDnXgesvdjSetTolerance`.
        #   max_sweeps : sweep limit forwarded to `cusolverDnXgesvdjSetMaxSweeps`.
        # Throws
        #   DimensionMismatch when the sizes of S, U or Vᴴ are inconsistent with A.
        #   Whatever `CUSOLVER.chkargsok` throws for the reported `info` value.
        #! format: off
        function gesvdj!(A::StridedCuMatrix{$elty},
                         S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
                         U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
                         Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
                         tol::$relty=eps($relty),
                         max_sweeps::Int=100)
        #! format: on
            chkstride1(A, U, Vᴴ, S)
            m, n = size(A)
            minmn = min(m, n)

            # Decide between values-only ('N') and values+vectors ('V'), and between the
            # economy (econ = 1) and full (econ = 0) shapes of the factors.
            if length(U) == 0 && length(Vᴴ) == 0
                jobz = 'N'
                econ = 0
            else
                jobz = 'V'
                size(U, 1) == m ||
                    throw(DimensionMismatch("row size mismatch between A and U"))
                size(Vᴴ, 2) == n ||
                    throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
                if size(U, 2) == size(Vᴴ, 1) == minmn
                    econ = 1
                elseif size(U, 2) == m && size(Vᴴ, 1) == n
                    econ = 0
                else
                    throw(DimensionMismatch("invalid column size of U or row size of Vᴴ"))
                end
            end
            length(S) == minmn ||
                throw(DimensionMismatch("length mismatch between A and S"))

            # The routine is handed V rather than Vᴴ, so it always writes into a scratch
            # matrix Ṽ. In the values-only case scratch storage is still allocated for
            # both factors (the routine appears to require the memory regardless).
            Ṽ = (jobz == 'V') ? similar(Vᴴ') : similar(Vᴴ, (n, minmn))
            Ũ = (jobz == 'V') ? U : similar(U, (m, minmn))
            lda = max(1, stride(A, 2))
            ldu = max(1, stride(Ũ, 2))
            ldv = max(1, stride(Ṽ, 2))

            # Acquire the handle first so that a failure here cannot leak the info object.
            dh = CUSOLVER.dense_handle()

            params = Ref{CUSOLVER.gesvdjInfo_t}(C_NULL)
            CUSOLVER.cusolverDnCreateGesvdjInfo(params)
            try
                CUSOLVER.cusolverDnXgesvdjSetTolerance(params[], tol)
                CUSOLVER.cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)

                # Workspace size in bytes; the query reports it in elements of `elty`.
                function bufferSize()
                    out = Ref{Cint}(0)
                    CUSOLVER.$bname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
                                    out, params[])
                    return out[] * sizeof($elty)
                end

                CUSOLVER.with_workspace(dh.workspace_gpu, bufferSize) do buffer
                    return CUSOLVER.$fname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
                                           buffer, sizeof(buffer) ÷ sizeof($elty), dh.info,
                                           params[])
                end

                info = @allowscalar dh.info[1]
                CUSOLVER.chkargsok(BlasInt(info))
            finally
                # Always release the info object, including when the call or the
                # argument check above throws.
                CUSOLVER.cusolverDnDestroyGesvdjInfo(params[])
            end

            # Convert the computed V back into the caller-provided Vᴴ.
            if jobz == 'V'
                adjoint!(Vᴴ, Ṽ)
            end
            return U, S, Vᴴ
        end
    end
end
252249
253250# for (jname, bname, fname, elty, relty) in
254251# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
0 commit comments