Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ for (fname, elty, relty) in

AMDGPU.unsafe_free!(dev_residual)
AMDGPU.unsafe_free!(dev_n_sweeps)
return U, S, Vᴴ
return (S, U, Vᴴ)
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
using CUDA
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -30,6 +30,7 @@ _gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

end
61 changes: 60 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in
if jobz == 'V'
adjoint!(Vᴴ, Ṽ)
end
return U, S, Vᴴ
return S, U, Vᴴ
end
end
end

# Wrapper for randomized SVD
function Xgesvdr!(A::StridedCuMatrix{T},
S::StridedCuVector=similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2));
k::Int=length(S),
p::Int=min(size(A)...)-k-1,
niters::Int=1) where {T<:BlasFloat}
chkstride1(A, U, S, Vᴴ)
m, n = size(A)
minmn = min(m, n)
jobu = length(U) == 0 ? 'N' : 'S'
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
R = eltype(S)
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
R == real(T) ||
throw(ArgumentError("S does not have the matching real `eltype` of A"))

Ṽ = similar(Vᴴ, (n, n))
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
lda = max(1, stride(A, 2))
ldu = max(1, stride(Ũ, 2))
ldv = max(1, stride(Ṽ, 2))
params = CUSOLVER.CuSolverParameters()
dh = CUSOLVER.dense_handle()

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
T, out_gpu, out_cpu)

return out_gpu[], out_cpu[]
end
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
bufferSize()...) do buffer_gpu, buffer_cpu
return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
T, buffer_gpu, sizeof(buffer_gpu),
buffer_cpu, sizeof(buffer_cpu),
dh.info)
end

flag = @allowscalar dh.info[1]
CUSOLVER.chklapackerror(BlasInt(flag))
if Ũ !== U && length(U) > 0
U .= view(Ũ, 1:m, 1:size(U, 2))
end
if length(Vᴴ) > 0
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
end
Ũ !== U && CUDA.unsafe_free!(Ũ)
CUDA.unsafe_free!(Ṽ)

return S, U, Vᴴ
end

# for (jname, bname, fname, elty, relty) in
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi,
LQViaTransposedQR,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

Expand Down
2 changes: 1 addition & 1 deletion src/common/gauge.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function gaugefix!(V::AbstractMatrix)
for j in axes(V, 2)
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmaxabs(v)))
@inbounds v .*= s
end
return V
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function copy_input(::typeof(eig_vals), A::AbstractMatrix)
end
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
D, V = DV
Expand All @@ -19,7 +19,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
@check_scalar(V, A, complex)
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D)
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
@assert D isa AbstractVector
Expand Down Expand Up @@ -51,7 +51,7 @@ end
# --------------
# actual implementation
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
check_input(eig_full!, A, DV)
check_input(eig_full!, A, DV, alg)
D, V = DV
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand All @@ -66,7 +66,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
end

function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
check_input(eig_vals!, A, D)
check_input(eig_vals!, A, D, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
end
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
D, V = DV
Expand All @@ -19,7 +19,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D)
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
@assert D isa AbstractVector
@check_size(D, (n,))
Expand Down Expand Up @@ -48,7 +48,7 @@ end
# Implementation
# --------------
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
check_input(eigh_full!, A, DV)
check_input(eigh_full!, A, DV, alg)
D, V = DV
Dd = D.diag
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
Expand All @@ -70,7 +70,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
end

function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
check_input(eigh_vals!, A, D)
check_input(eigh_vals!, A, D, alg)
V = similar(A, (size(A, 1), 0))
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
YALAPACK.heevr!(A, D, V; alg.kwargs...)
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/gen_eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix
return copy_input(gen_eig_full, A, B)
end

function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
Expand All @@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
@check_scalar(V, B, complex)
return nothing
end
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
Expand Down Expand Up @@ -57,7 +57,7 @@ end
# --------------
# actual implementation
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_full!, A, B, WV)
check_input(gen_eig_full!, A, B, WV, alg)
W, V = WV
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand All @@ -72,7 +72,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig
end

function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_vals!, A, B, W)
check_input(gen_eig_vals!, A, B, W, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand Down
18 changes: 9 additions & 9 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function copy_input(::typeof(lq_null), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end

function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
m, n = size(A)
L, Q = LQ
@assert L isa AbstractMatrix && Q isa AbstractMatrix
Expand All @@ -20,7 +20,7 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
@check_scalar(Q, A)
return nothing
end
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
m, n = size(A)
minmn = min(m, n)
L, Q = LQ
Expand All @@ -31,7 +31,7 @@ function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
@check_scalar(Q, A)
return nothing
end
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ)
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm)
m, n = size(A)
minmn = min(m, n)
@assert Nᴴ isa AbstractMatrix
Expand Down Expand Up @@ -66,36 +66,36 @@ end
# --------------
# actual implementation
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_full!, A, LQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_full!, A, LQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_compact!, A, LQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_compact!, A, LQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
check_input(lq_null!, A, Nᴴ)
check_input(lq_null!, A, Nᴴ, alg)
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
return Nᴴ
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
check_input(lq_null!, A, Nᴴ)
check_input(lq_null!, A, Nᴴ, alg)
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
return Nᴴ
end
Expand Down
Loading
Loading