Skip to content

Commit 8e1c8e7

Browse files
kshyatt and Katharine Hyatt authored
[WIP] Attempting to wrap randomized SVD (#41)
* Attempting to wrap randomized SVD
* Remove bad line resetting k
* Incremental progress
* Working svd_trunc
* Move SVD gauge fix to its own function
* Move randomized SVD to svd_trunc
* Fix gaugefix arguments
* Fix svd_trunc

---------

Co-authored-by: Katharine Hyatt <[email protected]>
1 parent 1e15667 commit 8e1c8e7

File tree

17 files changed

+271
-147
lines changed

17 files changed

+271
-147
lines changed

ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ for (fname, elty, relty) in
164164

165165
AMDGPU.unsafe_free!(dev_residual)
166166
AMDGPU.unsafe_free!(dev_n_sweeps)
167-
return U, S, Vᴴ
167+
return (S, U, Vᴴ)
168168
end
169169
end
170170
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
9+
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
1010
using CUDA
1111
using LinearAlgebra
1212
using LinearAlgebra: BlasFloat
@@ -30,6 +30,7 @@ _gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
3030
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C)
3131
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ)
3232
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
33+
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
3334
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
3435

3536
end

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in
242242
if jobz == 'V'
243243
adjoint!(Vᴴ, Ṽ)
244244
end
245-
return U, S, Vᴴ
245+
return S, U, Vᴴ
246246
end
247247
end
248248
end
249249

250+
# Wrapper for randomized SVD
251+
function Xgesvdr!(A::StridedCuMatrix{T},
252+
S::StridedCuVector=similar(A, real(T), min(size(A)...)),
253+
U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)),
254+
Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2));
255+
k::Int=length(S),
256+
p::Int=min(size(A)...)-k-1,
257+
niters::Int=1) where {T<:BlasFloat}
258+
chkstride1(A, U, S, Vᴴ)
259+
m, n = size(A)
260+
minmn = min(m, n)
261+
jobu = length(U) == 0 ? 'N' : 'S'
262+
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
263+
R = eltype(S)
264+
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
265+
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
266+
R == real(T) ||
267+
throw(ArgumentError("S does not have the matching real `eltype` of A"))
268+
269+
Ṽ = similar(Vᴴ, (n, n))
270+
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
271+
lda = max(1, stride(A, 2))
272+
ldu = max(1, stride(Ũ, 2))
273+
ldv = max(1, stride(Ṽ, 2))
274+
params = CUSOLVER.CuSolverParameters()
275+
dh = CUSOLVER.dense_handle()
276+
277+
function bufferSize()
278+
out_cpu = Ref{Csize_t}(0)
279+
out_gpu = Ref{Csize_t}(0)
280+
CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
281+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
282+
T, out_gpu, out_cpu)
283+
284+
return out_gpu[], out_cpu[]
285+
end
286+
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
287+
bufferSize()...) do buffer_gpu, buffer_cpu
288+
return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
289+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
290+
T, buffer_gpu, sizeof(buffer_gpu),
291+
buffer_cpu, sizeof(buffer_cpu),
292+
dh.info)
293+
end
294+
295+
flag = @allowscalar dh.info[1]
296+
CUSOLVER.chklapackerror(BlasInt(flag))
297+
if Ũ !== U && length(U) > 0
298+
U .= view(Ũ, 1:m, 1:size(U, 2))
299+
end
300+
if length(Vᴴ) > 0
301+
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
302+
end
303+
Ũ !== U && CUDA.unsafe_free!(Ũ)
304+
CUDA.unsafe_free!(Ṽ)
305+
306+
return S, U, Vᴴ
307+
end
308+
250309
# for (jname, bname, fname, elty, relty) in
251310
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
252311
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3333
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3434
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3535
LQViaTransposedQR,
36-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi,
36+
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
3737
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
3838
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3939

src/common/gauge.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function gaugefix!(V::AbstractMatrix)
22
for j in axes(V, 2)
33
v = view(V, :, j)
4-
s = conj(sign(argmax(abs, v)))
4+
s = conj(sign(_argmaxabs(v)))
55
@inbounds v .*= s
66
end
77
return V

src/implementations/eig.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function copy_input(::typeof(eig_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
1010

11-
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
11+
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1212
m, n = size(A)
1313
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
@@ -19,7 +19,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
1919
@check_scalar(V, A, complex)
2020
return nothing
2121
end
22-
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D)
22+
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2323
m, n = size(A)
2424
m == n || throw(DimensionMismatch("square input matrix expected"))
2525
@assert D isa AbstractVector
@@ -51,7 +51,7 @@ end
5151
# --------------
5252
# actual implementation
5353
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
54-
check_input(eig_full!, A, DV)
54+
check_input(eig_full!, A, DV, alg)
5555
D, V = DV
5656
if alg isa LAPACK_Simple
5757
isempty(alg.kwargs) ||
@@ -66,7 +66,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
6666
end
6767

6868
function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
69-
check_input(eig_vals!, A, D)
69+
check_input(eig_vals!, A, D, alg)
7070
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
7171
if alg isa LAPACK_Simple
7272
isempty(alg.kwargs) ||

src/implementations/eigh.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
1010

11-
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
11+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1212
m, n = size(A)
1313
m == n || throw(DimensionMismatch("square input matrix expected"))
1414
D, V = DV
@@ -19,7 +19,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
1919
@check_scalar(V, A)
2020
return nothing
2121
end
22-
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D)
22+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2323
m, n = size(A)
2424
@assert D isa AbstractVector
2525
@check_size(D, (n,))
@@ -48,7 +48,7 @@ end
4848
# Implementation
4949
# --------------
5050
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
51-
check_input(eigh_full!, A, DV)
51+
check_input(eigh_full!, A, DV, alg)
5252
D, V = DV
5353
Dd = D.diag
5454
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
@@ -70,7 +70,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
7070
end
7171

7272
function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
73-
check_input(eigh_vals!, A, D)
73+
check_input(eigh_vals!, A, D, alg)
7474
V = similar(A, (size(A, 1), 0))
7575
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
7676
YALAPACK.heevr!(A, D, V; alg.kwargs...)

src/implementations/gen_eig.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix
77
return copy_input(gen_eig_full, A, B)
88
end
99

10-
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
10+
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
1111
ma, na = size(A)
1212
mb, nb = size(B)
1313
ma == na || throw(DimensionMismatch("square input matrix A expected"))
@@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
2424
@check_scalar(V, B, complex)
2525
return nothing
2626
end
27-
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
27+
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
2828
ma, na = size(A)
2929
mb, nb = size(B)
3030
ma == na || throw(DimensionMismatch("square input matrix A expected"))
@@ -57,7 +57,7 @@ end
5757
# --------------
5858
# actual implementation
5959
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
60-
check_input(gen_eig_full!, A, B, WV)
60+
check_input(gen_eig_full!, A, B, WV, alg)
6161
W, V = WV
6262
if alg isa LAPACK_Simple
6363
isempty(alg.kwargs) ||
@@ -72,7 +72,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig
7272
end
7373

7474
function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
75-
check_input(gen_eig_vals!, A, B, W)
75+
check_input(gen_eig_vals!, A, B, W, alg)
7676
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
7777
if alg isa LAPACK_Simple
7878
isempty(alg.kwargs) ||

src/implementations/lq.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function copy_input(::typeof(lq_null), A::AbstractMatrix)
1010
return copy!(similar(A, float(eltype(A))), A)
1111
end
1212

13-
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
13+
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
1414
m, n = size(A)
1515
L, Q = LQ
1616
@assert L isa AbstractMatrix && Q isa AbstractMatrix
@@ -20,7 +20,7 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
2020
@check_scalar(Q, A)
2121
return nothing
2222
end
23-
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
23+
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
2424
m, n = size(A)
2525
minmn = min(m, n)
2626
L, Q = LQ
@@ -31,7 +31,7 @@ function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
3131
@check_scalar(Q, A)
3232
return nothing
3333
end
34-
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ)
34+
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm)
3535
m, n = size(A)
3636
minmn = min(m, n)
3737
@assert Nᴴ isa AbstractMatrix
@@ -66,36 +66,36 @@ end
6666
# --------------
6767
# actual implementation
6868
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
69-
check_input(lq_full!, A, LQ)
69+
check_input(lq_full!, A, LQ, alg)
7070
L, Q = LQ
7171
_lapack_lq!(A, L, Q; alg.kwargs...)
7272
return L, Q
7373
end
7474
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
75-
check_input(lq_full!, A, LQ)
75+
check_input(lq_full!, A, LQ, alg)
7676
L, Q = LQ
7777
lq_via_qr!(A, L, Q, alg.qr_alg)
7878
return L, Q
7979
end
8080
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
81-
check_input(lq_compact!, A, LQ)
81+
check_input(lq_compact!, A, LQ, alg)
8282
L, Q = LQ
8383
_lapack_lq!(A, L, Q; alg.kwargs...)
8484
return L, Q
8585
end
8686
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
87-
check_input(lq_compact!, A, LQ)
87+
check_input(lq_compact!, A, LQ, alg)
8888
L, Q = LQ
8989
lq_via_qr!(A, L, Q, alg.qr_alg)
9090
return L, Q
9191
end
9292
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
93-
check_input(lq_null!, A, Nᴴ)
93+
check_input(lq_null!, A, Nᴴ, alg)
9494
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
9595
return Nᴴ
9696
end
9797
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
98-
check_input(lq_null!, A, Nᴴ)
98+
check_input(lq_null!, A, Nᴴ, alg)
9999
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
100100
return Nᴴ
101101
end

0 commit comments

Comments
 (0)