1+ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
2+ CUSOLVER_SVDPolar,
3+ CUSOLVER_Jacobi}
4+
5+ # CUSOLVER SVD implementation
6+ function MatrixAlgebraKit. svd_full!(A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
7+ check_input(svd_full!, A, USVᴴ)
8+ U, S, Vᴴ = USVᴴ
9+ fill!(S, zero(eltype(S)))
10+ m, n = size(A)
11+ minmn = min(m, n)
12+ if alg isa CUSOLVER_QRIteration
13+ isempty(alg. kwargs) ||
14+ throw(ArgumentError(" LAPACK_QRIteration does not accept any keyword arguments" ))
15+ YACUSOLVER. gesvd!(A, view(S, 1 : minmn, 1 ), U, Vᴴ)
16+ elseif alg isa CUSOLVER_SVDPolar
17+ YACUSOLVER. Xgesvdp!(A, view(S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
18+ # elseif alg isa LAPACK_Bisection
19+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
20+ # elseif alg isa LAPACK_Jacobi
21+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22+ else
23+ throw(ArgumentError(" Unsupported SVD algorithm" ))
24+ end
25+ diagview(S) .= view(S, 1 : minmn, 1 )
26+ view(S, 2 : minmn, 1 ) .= zero(eltype(S))
27+ # TODO : make this controllable using a `gaugefix` keyword argument
28+ for j in 1 : max(m, n)
29+ if j <= minmn
30+ u = view(U, :, j)
31+ v = view(Vᴴ, j, :)
32+ s = conj(sign(_argmaxabs(u)))
33+ u .*= s
34+ v .*= conj(s)
35+ elseif j <= m
36+ u = view(U, :, j)
37+ s = conj(sign(_argmaxabs(u)))
38+ u .*= s
39+ else
40+ v = view(Vᴴ, j, :)
41+ s = conj(sign(_argmaxabs(v)))
42+ v .*= s
43+ end
44+ end
45+ return USVᴴ
46+ end
47+
48+ function MatrixAlgebraKit. svd_compact!(A:: CuMatrix , USVᴴ, alg:: CUSOLVER_SVDAlgorithm )
49+ check_input(svd_compact!, A, USVᴴ)
50+ U, S, Vᴴ = USVᴴ
51+ if alg isa CUSOLVER_QRIteration
52+ isempty(alg. kwargs) ||
53+ throw(ArgumentError(" CUSOLVER_QRIteration does not accept any keyword arguments" ))
54+ YACUSOLVER. gesvd!(A, S. diag, U, Vᴴ)
55+ elseif alg isa CUSOLVER_SVDPolar
56+ YACUSOLVER. Xgesvdp!(A, S. diag, U, Vᴴ; alg. kwargs... )
57+ # elseif alg isa LAPACK_DivideAndConquer
58+ # isempty(alg.kwargs) ||
59+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
60+ # YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
61+ # elseif alg isa LAPACK_Bisection
62+ # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
63+ # elseif alg isa LAPACK_Jacobi
64+ # isempty(alg.kwargs) ||
65+ # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
66+ # YALAPACK.gesvj!(A, S.diag, U, Vᴴ)
67+ else
68+ throw(ArgumentError(" Unsupported SVD algorithm" ))
69+ end
70+ # TODO : make this controllable using a `gaugefix` keyword argument
71+ for j in 1 : size(U, 2 )
72+ u = view(U, :, j)
73+ v = view(Vᴴ, j, :)
74+ s = conj(sign(_argmaxabs(u)))
75+ u .*= s
76+ v .*= conj(s)
77+ end
78+ return USVᴴ
79+ end
80+ _argmaxabs(x) = reduce(_largest, x; init= zero(eltype(x)))
81+ _largest(x, y) = abs(x) < abs(y) ? y : x
82+
83+ function MatrixAlgebraKit. svd_vals!(A:: CuMatrix , S, alg:: CUSOLVER_SVDAlgorithm )
84+ check_input(svd_vals!, A, S)
85+ U, Vᴴ = similar(A, (0 , 0 )), similar(A, (0 , 0 ))
86+ if alg isa CUSOLVER_QRIteration
87+ isempty(alg. kwargs) ||
88+ throw(ArgumentError(" CUSOLVER_QRIteration does not accept any keyword arguments" ))
89+ YACUSOLVER. gesvd!(A, S, U, Vᴴ)
90+ elseif alg isa CUSOLVER_SVDPolar
91+ YACUSOLVER. Xgesvdp!(A, S, U, Vᴴ; alg. kwargs... )
92+ # elseif alg isa LAPACK_DivideAndConquer
93+ # isempty(alg.kwargs) ||
94+ # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
95+ # YALAPACK.gesdd!(A, S, U, Vᴴ)
96+ # elseif alg isa LAPACK_Bisection
97+ # YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
98+ # elseif alg isa LAPACK_Jacobi
99+ # isempty(alg.kwargs) ||
100+ # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
101+ # YALAPACK.gesvj!(A, S, U, Vᴴ)
102+ else
103+ throw(ArgumentError(" Unsupported SVD algorithm" ))
104+ end
105+ return S
106+ end
0 commit comments