Skip to content

Commit fa50d31

Browse files
committed
first svd support
1 parent 13b5b0c commit fa50d31

File tree

14 files changed

+460
-180
lines changed

14 files changed

+460
-180
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
3939
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4040

4141
[targets]
42-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
42+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"]

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@ using LinearAlgebra: BlasFloat
1111

1212
include("yacusolver.jl")
1313
include("implementations/qr.jl")
14-
include("implementations/lq.jl")
14+
include("implementations/svd.jl")
15+
16+
function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
17+
return CUSOLVER_HouseholderQR(; kwargs...)
18+
end
19+
function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
20+
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
21+
return LQViaTransposedQR(qr_alg)
22+
end
23+
function MatrixAlgebraKit.default_svd_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
24+
return CUSOLVER_QRIteration(; kwargs...)
25+
end
1526

1627
end

ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,4 @@
1-
"""
2-
CUSOLVER_HouseholderQR(; positive = false)
3-
4-
Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of
5-
a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that
6-
the diagonal elements of `R` are non-negative.
7-
"""
8-
@algdef CUSOLVER_HouseholderQR
9-
10-
function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
11-
return CUSOLVER_HouseholderQR(; kwargs...)
12-
end
13-
14-
# Implementation
15-
# --------------
16-
# actual implementation
1+
# CUSOLVER QR implementation
172
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR)
183
check_input(qr_full!, A, QR)
194
Q, R = QR
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
2+
CUSOLVER_SVDPolar,
3+
CUSOLVER_Jacobi}
4+
5+
# CUSOLVER SVD implementation
6+
function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
7+
check_input(svd_full!, A, USVᴴ)
8+
U, S, Vᴴ = USVᴴ
9+
fill!(S, zero(eltype(S)))
10+
m, n = size(A)
11+
minmn = min(m, n)
12+
if alg isa CUSOLVER_QRIteration
13+
isempty(alg.kwargs) ||
14+
throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments"))
15+
YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
16+
elseif alg isa CUSOLVER_SVDPolar
17+
YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
18+
# elseif alg isa LAPACK_Bisection
19+
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
20+
# elseif alg isa LAPACK_Jacobi
21+
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
22+
else
23+
throw(ArgumentError("Unsupported SVD algorithm"))
24+
end
25+
diagview(S) .= view(S, 1:minmn, 1)
26+
view(S, 2:minmn, 1) .= zero(eltype(S))
27+
# TODO: make this controllable using a `gaugefix` keyword argument
28+
for j in 1:max(m, n)
29+
if j <= minmn
30+
u = view(U, :, j)
31+
v = view(Vᴴ, j, :)
32+
s = conj(sign(_argmaxabs(u)))
33+
u .*= s
34+
v .*= conj(s)
35+
elseif j <= m
36+
u = view(U, :, j)
37+
s = conj(sign(_argmaxabs(u)))
38+
u .*= s
39+
else
40+
v = view(Vᴴ, j, :)
41+
s = conj(sign(_argmaxabs(v)))
42+
v .*= s
43+
end
44+
end
45+
return USVᴴ
46+
end
47+
48+
function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
49+
check_input(svd_compact!, A, USVᴴ)
50+
U, S, Vᴴ = USVᴴ
51+
if alg isa CUSOLVER_QRIteration
52+
isempty(alg.kwargs) ||
53+
throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
54+
YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ)
55+
elseif alg isa CUSOLVER_SVDPolar
56+
YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
57+
# elseif alg isa LAPACK_DivideAndConquer
58+
# isempty(alg.kwargs) ||
59+
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
60+
# YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
61+
# elseif alg isa LAPACK_Bisection
62+
# YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
63+
# elseif alg isa LAPACK_Jacobi
64+
# isempty(alg.kwargs) ||
65+
# throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
66+
# YALAPACK.gesvj!(A, S.diag, U, Vᴴ)
67+
else
68+
throw(ArgumentError("Unsupported SVD algorithm"))
69+
end
70+
# TODO: make this controllable using a `gaugefix` keyword argument
71+
for j in 1:size(U, 2)
72+
u = view(U, :, j)
73+
v = view(Vᴴ, j, :)
74+
s = conj(sign(_argmaxabs(u)))
75+
u .*= s
76+
v .*= conj(s)
77+
end
78+
return USVᴴ
79+
end
80+
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))
81+
_largest(x, y) = abs(x) < abs(y) ? y : x
82+
83+
function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
84+
check_input(svd_vals!, A, S)
85+
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
86+
if alg isa CUSOLVER_QRIteration
87+
isempty(alg.kwargs) ||
88+
throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
89+
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
90+
elseif alg isa CUSOLVER_SVDPolar
91+
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
92+
# elseif alg isa LAPACK_DivideAndConquer
93+
# isempty(alg.kwargs) ||
94+
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
95+
# YALAPACK.gesdd!(A, S, U, Vᴴ)
96+
# elseif alg isa LAPACK_Bisection
97+
# YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
98+
# elseif alg isa LAPACK_Jacobi
99+
# isempty(alg.kwargs) ||
100+
# throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
101+
# YALAPACK.gesvj!(A, S, U, Vᴴ)
102+
else
103+
throw(ArgumentError("Unsupported SVD algorithm"))
104+
end
105+
return S
106+
end

0 commit comments

Comments
 (0)