diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 970e7600..1d6c2e32 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -2,7 +2,7 @@ env: SECRET_CODECOV_TOKEN: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw==" steps: - - label: "Julia v1" + - label: "Julia v1 -- CUDA" plugins: - JuliaCI/julia#v1: version: "1" @@ -17,8 +17,7 @@ steps: if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 30 -steps: - - label: "Julia LTS" + - label: "Julia LTS -- CUDA" plugins: - JuliaCI/julia#v1: version: "1.10" # "lts" isn't valid @@ -32,3 +31,35 @@ steps: cuda: "*" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 30 + + - label: "Julia v1 -- AMDGPU" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + dirs: + - src + - ext + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 + + - label: "Julia LTS -- AMDGPU" + plugins: + - JuliaCI/julia#v1: + version: "1.10" # "lts" isn't valid + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + dirs: + - src + - ext + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 diff --git a/Project.toml b/Project.toml index efc996b4..fb946847 100644 --- a/Project.toml +++ b/Project.toml @@ -7,14 +7,17 @@ version = "0.2.5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = 
"052768ef-5323-5732-b1bb-66c8b64840ba" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" +MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" [compat] +AMDGPU = "2" Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" @@ -39,4 +42,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"] diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl new file mode 100644 index 00000000..0f8c3513 --- /dev/null +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -0,0 +1,35 @@ +module MatrixAlgebraKitAMDGPUExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: @algdef, Algorithm, check_input +using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! +using MatrixAlgebraKit: diagview, sign_safe +using MatrixAlgebraKit: LQViaTransposedQR +using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm +import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +using AMDGPU +using LinearAlgebra +using LinearAlgebra: BlasFloat + +include("yarocsolver.jl") + +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} + return ROCSOLVER_HouseholderQR(; kwargs...) +end +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} + qr_alg = ROCSOLVER_HouseholderQR(; kwargs...) + return LQViaTransposedQR(qr_alg) +end +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) 
where {T<:StridedROCMatrix}
+    return ROCSOLVER_QRIteration(; kwargs...)
+end
+
+_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
+_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
+_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) = YArocSOLVER.unmqr!(side, trans, A, τ, C)
+_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = YArocSOLVER.gesvd!(A, S, U, Vᴴ)
+# not yet supported
+#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
+_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
+
+end
diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
new file mode 100644
index 00000000..97c8e6ee
--- /dev/null
+++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
@@ -0,0 +1,530 @@
+module YArocSOLVER
+
+using LinearAlgebra
+using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing
+using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
+
+using AMDGPU
+using AMDGPU: @allowscalar
+using AMDGPU.rocSOLVER
+using AMDGPU.rocBLAS
+
+# QR methods are implemented with full access to allocated arrays, so we do not need to redo this:
+using AMDGPU.rocSOLVER: geqrf!, ormqr!, orgqr!
+const unmqr! = ormqr!
+const ungqr! = orgqr!
+
+# Wrapper for SVD via QR Iteration (one `gesvd!` method per rocsolver_?gesvd routine).
+#
+# The sizes of `U` and `Vᴴ` select the job: empty -> not computed, `minmn` columns/rows
+# -> singular vectors only (or overwrite mode when the argument *is* `A`), square ->
+# all vectors. Throws for `m < n` and for mismatched sizes. Returns `(S, U, Vᴴ)`.
+for (fname, elty, relty) in
+    ((:rocsolver_sgesvd, :Float32, :Float32),
+     (:rocsolver_dgesvd, :Float64, :Float64),
+     (:rocsolver_cgesvd, :ComplexF32, :Float32),
+     (:rocsolver_zgesvd, :ComplexF64, :Float64))
+    @eval begin
+        #! format: off
+        function gesvd!(A::StridedROCMatrix{$elty},
+                        S::StridedROCVector{$relty}=similar(A, $relty, min(size(A)...)),
+                        U::StridedROCMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
+                        Vᴴ::StridedROCMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2)))
+        #! format: on
+            chkstride1(A, U, Vᴴ, S)
+            m, n = size(A)
+            (m < n) && throw(ArgumentError("rocSOLVER's gesvd requires m ≥ n"))
+            minmn = min(m, n)
+            if length(U) == 0
+                jobu = rocSOLVER.rocblas_svect_none
+            else
+                size(U, 1) == m ||
+                    throw(DimensionMismatch("row size mismatch between A and U"))
+                if size(U, 2) == minmn
+                    if U === A
+                        jobu = rocSOLVER.rocblas_svect_overwrite
+                    else
+                        jobu = rocSOLVER.rocblas_svect_singular
+                    end
+                elseif size(U, 2) == m
+                    jobu = rocSOLVER.rocblas_svect_all
+                else
+                    throw(DimensionMismatch("invalid column size of U"))
+                end
+            end
+            if length(Vᴴ) == 0
+                jobvt = rocSOLVER.rocblas_svect_none
+            else
+                size(Vᴴ, 2) == n ||
+                    throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
+                if size(Vᴴ, 1) == minmn
+                    if Vᴴ === A
+                        jobvt = rocSOLVER.rocblas_svect_overwrite
+                    else
+                        jobvt = rocSOLVER.rocblas_svect_singular
+                    end
+                elseif size(Vᴴ, 1) == n
+                    jobvt = rocSOLVER.rocblas_svect_all
+                else
+                    throw(DimensionMismatch("invalid row size of Vᴴ"))
+                end
+            end
+            length(S) == minmn ||
+                throw(DimensionMismatch("length mismatch between A and S"))
+
+            lda = max(1, stride(A, 2))
+            ldu = max(1, stride(U, 2))
+            ldv = max(1, stride(Vᴴ, 2))
+
+            # the scratch vector has `minmn - 1` entries; clamp at zero so an empty
+            # input does not request a negative length
+            rwork = ROCArray{$relty}(undef, max(minmn - 1, 0))
+            dh = rocBLAS.handle()
+            dev_info = ROCVector{Cint}(undef, 1)
+            rocSOLVER.$fname(dh, jobu, jobvt, m, n,
+                             A, lda, S, U, ldu, Vᴴ, ldv,
+                             rwork, convert(rocSOLVER.rocblas_workmode, 'I'),
+                             dev_info)
+
+            # read `info` back first, then release the scratch buffers *before* the
+            # status check so that a failing check does not skip the frees
+            info = @allowscalar dev_info[1]
+            AMDGPU.unsafe_free!(rwork)
+            AMDGPU.unsafe_free!(dev_info)
+            rocSOLVER.chkargsok(BlasInt(info))
+
+            return (S, U, Vᴴ)
+        end
+    end
+end
+
+# Wrapper for SVD via Jacobi (one `gesvdj!` method per rocsolver_?gesvdj routine).
+#
+# The sizes of `U` and `Vᴴ` select the job as in `gesvd!`, except that passing `A` itself
+# as `U` or `Vᴴ` throws. `tol` and `max_sweeps` are forwarded. Returns `U, S, Vᴴ`.
+for (fname, elty, relty) in
+    ((:rocsolver_sgesvdj, :Float32, :Float32),
+     (:rocsolver_dgesvdj, :Float64, :Float64),
+     (:rocsolver_cgesvdj, :ComplexF32, :Float32),
+     (:rocsolver_zgesvdj, :ComplexF64, :Float64))
+    @eval begin
+        #! format: off
+        function gesvdj!(A::StridedROCMatrix{$elty},
+                         S::StridedROCVector{$relty}=similar(A, $relty, min(size(A)...)),
+                         U::StridedROCMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
+                         Vᴴ::StridedROCMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
+                         tol::$relty=eps($relty),
+                         max_sweeps::Int=100,
+                         )
+        #! format: on
+            chkstride1(A, U, Vᴴ, S)
+            m, n = size(A)
+            minmn = min(m, n)
+
+            if length(U) == 0
+                jobu = rocSOLVER.rocblas_svect_none
+            else
+                size(U, 1) == m ||
+                    throw(DimensionMismatch("row size mismatch between A and U"))
+                if size(U, 2) == minmn
+                    if U === A
+                        throw(ArgumentError("overwrite mode is not supported for gesvdj"))
+                    else
+                        jobu = rocSOLVER.rocblas_svect_singular
+                    end
+                elseif size(U, 2) == m
+                    jobu = rocSOLVER.rocblas_svect_all
+                else
+                    throw(DimensionMismatch("invalid column size of U"))
+                end
+            end
+            if length(Vᴴ) == 0
+                jobvt = rocSOLVER.rocblas_svect_none
+            else
+                size(Vᴴ, 2) == n ||
+                    throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
+                if size(Vᴴ, 1) == minmn
+                    if Vᴴ === A
+                        throw(ArgumentError("overwrite mode is not supported for gesvdj"))
+                    else
+                        jobvt = rocSOLVER.rocblas_svect_singular
+                    end
+                elseif size(Vᴴ, 1) == n
+                    jobvt = rocSOLVER.rocblas_svect_all
+                else
+                    throw(DimensionMismatch("invalid row size of Vᴴ"))
+                end
+            end
+            length(S) == minmn ||
+                throw(DimensionMismatch("length mismatch between A and S"))
+
+            lda = max(1, stride(A, 2))
+            ldu = max(1, stride(U, 2))
+            ldv = max(1, stride(Vᴴ, 2))
+            dev_info = ROCVector{Cint}(undef, 1)
+            dev_residual = ROCVector{$relty}(undef, 1)
+            dev_n_sweeps = ROCVector{Cint}(undef, 1)
+
+            dh = rocBLAS.handle()
+            rocSOLVER.$fname(dh, jobu, jobvt, m, n, A, lda, tol,
+                             dev_residual, max_sweeps, dev_n_sweeps,
+                             S, U, ldu, Vᴴ, ldv, dev_info,
+                             )
+
+            # release the scratch buffers *before* the status check so that a failing
+            # check does not skip the frees
+            info = @allowscalar dev_info[1]
+            AMDGPU.unsafe_free!(dev_residual)
+            AMDGPU.unsafe_free!(dev_n_sweeps)
+            AMDGPU.unsafe_free!(dev_info)
+            rocSOLVER.chkargsok(BlasInt(info))
+            return U, S, Vᴴ
+        end
+    end
+end
+
+# for (jname, bname, fname, elty, relty) in
+#     ((:sygvd!, :rocsolverDnSsygvd_bufferSize, :rocsolverDnSsygvd, :Float32, :Float32),
+#      (:sygvd!, :rocsolverDnDsygvd_bufferSize, :rocsolverDnDsygvd, :Float64, :Float64),
+#      (:hegvd!, :rocsolverDnChegvd_bufferSize, :rocsolverDnChegvd, :ComplexF32, :Float32),
+#      (:hegvd!, :rocsolverDnZhegvd_bufferSize, :rocsolverDnZhegvd, :ComplexF64, :Float64))
+#     @eval begin
+#         function $jname(itype::Int,
+#                         jobz::Char,
+#                         uplo::Char,
+#                         A::StridedROCMatrix{$elty},
+#                         B::StridedROCMatrix{$elty})
+#             chkuplo(uplo)
+#             nA, nB = checksquare(A, B)
+#             if nB != nA
+#                 throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
+#             end
+#             n = nA
+#             lda = max(1, stride(A, 2))
+#             ldb = max(1, stride(B, 2))
+#             W = CuArray{$relty}(undef, n)
+#             dh = rocBLAS.handle()
+
+#             function bufferSize()
+#                 out = Ref{Cint}(0)
+#                 $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, out)
+#                 return out[] * sizeof($elty)
+#             end
+
+#             with_workspace(dh.workspace_gpu, bufferSize) do buffer
+#                 return $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W,
+#                               buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
+#             end
+
+#             info = @allowscalar dh.info[1]
+#             chkargsok(BlasInt(info))
+
+#             if jobz == 'N'
+#                 return W
+#             elseif jobz == 'V'
+#                 return W, A, B
+#             end
+#         end
+#     end
+# end
+
+# for (jname, bname, fname, elty, relty) in
+#     ((:sygvj!, :rocsolverDnSsygvj_bufferSize, :rocsolverDnSsygvj, :Float32, :Float32),
+#      (:sygvj!, :rocsolverDnDsygvj_bufferSize, :rocsolverDnDsygvj, :Float64, :Float64),
+#      (:hegvj!, :rocsolverDnChegvj_bufferSize, :rocsolverDnChegvj, :ComplexF32, :Float32),
+#      (:hegvj!, :rocsolverDnZhegvj_bufferSize, :rocsolverDnZhegvj, :ComplexF64, :Float64))
+#     @eval begin
+#         function $jname(itype::Int,
+#                         jobz::Char,
+#                         uplo::Char,
+# A::StridedROCMatrix{$elty}, +# B::StridedROCMatrix{$elty}; +# tol::$relty=eps($relty), +# max_sweeps::Int=100) +# chkuplo(uplo) +# nA, nB = checksquare(A, B) +# if nB != nA +# throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!")) +# end +# n = nA +# lda = max(1, stride(A, 2)) +# ldb = max(1, stride(B, 2)) +# W = CuArray{$relty}(undef, n) +# params = Ref{syevjInfo_t}(C_NULL) +# rocsolverDnCreateSyevjInfo(params) +# rocsolverDnXsyevjSetTolerance(params[], tol) +# rocsolverDnXsyevjSetMaxSweeps(params[], max_sweeps) +# dh = rocBLAS.handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, +# out, params[]) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# rocsolverDnDestroySyevjInfo(params[]) + +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A, B +# end +# end +# end +# end + +# for (jname, bname, fname, elty, relty) in +# ((:syevjBatched!, :rocsolverDnSsyevjBatched_bufferSize, :rocsolverDnSsyevjBatched, +# :Float32, :Float32), +# (:syevjBatched!, :rocsolverDnDsyevjBatched_bufferSize, :rocsolverDnDsyevjBatched, +# :Float64, :Float64), +# (:heevjBatched!, :rocsolverDnCheevjBatched_bufferSize, :rocsolverDnCheevjBatched, +# :ComplexF32, :Float32), +# (:heevjBatched!, :rocsolverDnZheevjBatched_bufferSize, :rocsolverDnZheevjBatched, +# :ComplexF64, :Float64)) +# @eval begin +# function $jname(jobz::Char, +# uplo::Char, +# A::StridedROCArray{$elty}; +# tol::$relty=eps($relty), +# max_sweeps::Int=100) + +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = checksquare(A) +# lda = max(1, stride(A, 2)) +# batchSize = size(A, 3) +# W = CuArray{$relty}(undef, n, batchSize) +# params = Ref{syevjInfo_t}(C_NULL) + +# 
dh = rocBLAS.handle() +# resize!(dh.info, batchSize) + +# # Initialize the solver parameters +# rocsolverDnCreateSyevjInfo(params) +# rocsolverDnXsyevjSetTolerance(params[], tol) +# rocsolverDnXsyevjSetMaxSweeps(params[], max_sweeps) + +# # Calculate the workspace size +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, jobz, uplo, n, A, lda, W, out, params[], batchSize) +# return out[] * sizeof($elty) +# end + +# # Run the solver +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobz, uplo, n, A, lda, W, buffer, +# sizeof(buffer) ÷ sizeof($elty), dh.info, params[], batchSize) +# end + +# # Copy the solver info and delete the device memory +# info = @allowscalar collect(dh.info) + +# # Double check the solver's exit status +# for i in 1:batchSize +# chkargsok(BlasInt(info[i])) +# end + +# rocsolverDnDestroySyevjInfo(params[]) + +# # Return eigenvalues (in W) and possibly eigenvectors (in A) +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A +# end +# end +# end +# end + +# for (fname, elty) in ((:rocsolverDnSpotrsBatched, :Float32), +# (:rocsolverDnDpotrsBatched, :Float64), +# (:rocsolverDnCpotrsBatched, :ComplexF32), +# (:rocsolverDnZpotrsBatched, :ComplexF64)) +# @eval begin +# function potrsBatched!(uplo::Char, +# A::Vector{<:StridedROCMatrix{$elty}}, +# B::Vector{<:StridedROCVecOrMat{$elty}}) +# if length(A) != length(B) +# throw(DimensionMismatch("")) +# end +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = checksquare(A[1]) +# if size(B[1], 1) != n +# throw(DimensionMismatch("first dimension of B[i], $(size(B[1],1)), must match second dimension of A, $n")) +# end +# nrhs = size(B[1], 2) +# # cuSOLVER's Remark 1: only nrhs=1 is supported. 
+# if nrhs != 1 +# throw(ArgumentError("cuSOLVER only supports vectors for B")) +# end +# lda = max(1, stride(A[1], 2)) +# ldb = max(1, stride(B[1], 2)) +# batchSize = length(A) + +# Aptrs = unsafe_batch(A) +# Bptrs = unsafe_batch(B) + +# dh = rocBLAS.handle() + +# # Run the solver +# $fname(dh, uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, dh.info, batchSize) + +# # Copy the solver info and delete the device memory +# info = @allowscalar dh.info[1] +# chklapackerror(BlasInt(info)) + +# return B +# end +# end +# end + +# for (fname, elty) in ((:rocsolverDnSpotrfBatched, :Float32), +# (:rocsolverDnDpotrfBatched, :Float64), +# (:rocsolverDnCpotrfBatched, :ComplexF32), +# (:rocsolverDnZpotrfBatched, :ComplexF64)) +# @eval begin +# function potrfBatched!(uplo::Char, A::Vector{<:StridedROCMatrix{$elty}}) + +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = checksquare(A[1]) +# lda = max(1, stride(A[1], 2)) +# batchSize = length(A) + +# Aptrs = unsafe_batch(A) + +# dh = rocBLAS.handle() +# resize!(dh.info, batchSize) + +# # Run the solver +# $fname(dh, uplo, n, Aptrs, lda, dh.info, batchSize) + +# # Copy the solver info and delete the device memory +# info = @allowscalar collect(dh.info) + +# # Double check the solver's exit status +# for i in 1:batchSize +# chkargsok(BlasInt(info[i])) +# end + +# # info[i] > 0 means the leading minor of order info[i] is not positive definite +# # LinearAlgebra.LAPACK does not throw Exception here +# # to simplify calls to isposdef! 
and factorize +# return A, info +# end +# end +# end + +# # gesv +# function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true, +# residual_history::Bool=false, irs_precision::String="AUTO", +# refinement_solver::String="CLASSICAL", +# maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, +# tol_inner=Float64 = 0.0) where {T<:BlasFloat} +# params = CuSolverIRSParameters() +# info = CuSolverIRSInformation() +# n = checksquare(A) +# nrhs = size(B, 2) +# lda = max(1, stride(A, 2)) +# ldb = max(1, stride(B, 2)) +# ldx = max(1, stride(X, 2)) +# niters = Ref{Cint}() +# dh = rocBLAS.handle() + +# if irs_precision == "AUTO" +# (T == Float32) && (irs_precision = "R_32F") +# (T == Float64) && (irs_precision = "R_64F") +# (T == ComplexF32) && (irs_precision = "C_32F") +# (T == ComplexF64) && (irs_precision = "C_64F") +# else +# (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || +# error("$irs_precision is not supported.")) +# (T == Float64) && +# (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || +# error("$irs_precision is not supported.")) +# (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || +# error("$irs_precision is not supported.")) +# (T == ComplexF64) && +# (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || +# error("$irs_precision is not supported.")) +# end +# rocsolverDnIRSParamsSetSolverMainPrecision(params, T) +# rocsolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision) +# rocsolverDnIRSParamsSetRefinementSolver(params, refinement_solver) +# (tol != 0.0) && rocsolverDnIRSParamsSetTol(params, tol) +# (tol_inner != 0.0) && rocsolverDnIRSParamsSetTolInner(params, tol_inner) +# (maxiters != 0) && rocsolverDnIRSParamsSetMaxIters(params, maxiters) +# (maxiters_inner != 0) && rocsolverDnIRSParamsSetMaxItersInner(params, maxiters_inner) +# fallback ? 
rocsolverDnIRSParamsEnableFallback(params) : +# rocsolverDnIRSParamsDisableFallback(params) +# residual_history && rocsolverDnIRSInfosRequestResidual(info) + +# function bufferSize() +# buffer_size = Ref{Csize_t}(0) +# rocsolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size) +# return buffer_size[] +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return rocsolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb, +# X, ldx, buffer, sizeof(buffer), niters, dh.info) +# end + +# # Copy the solver flag and delete the device memory +# flag = @allowscalar dh.info[1] +# chklapackerror(BlasInt(flag)) + +# return X, info +# end + +# for (jname, bname, fname, elty, relty) in +# ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32), +# (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64), +# (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32), +# (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64)) +# @eval begin +# function $jname(jobz::Char, +# uplo::Char, +# A::StridedROCMatrix{$elty}) +# chkuplo(uplo) +# n = checksquare(A) +# lda = max(1, stride(A, 2)) +# W = CuArray{$relty}(undef, n) +# dh = rocBLAS.handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, jobz, uplo, n, A, lda, W, out) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobz, uplo, n, A, lda, W, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A +# end +# end +# end +# end + +end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 01a721f7..9fbc2c4d 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ 
b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -5,14 +5,14 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input
 using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
 using MatrixAlgebraKit: diagview, sign_safe
 using MatrixAlgebraKit: LQViaTransposedQR
-using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
+using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
+import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!
+import MatrixAlgebraKit: _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
 using CUDA
 using LinearAlgebra
 using LinearAlgebra: BlasFloat
 
 include("yacusolver.jl")
-include("implementations/qr.jl")
-include("implementations/svd.jl")
 
 function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
     return CUSOLVER_HouseholderQR(; kwargs...)
@@ -25,4 +25,25 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:
     return CUSOLVER_QRIteration(; kwargs...)
 end
 
+
+# Forward the imported `_gpu_*` hooks to the `YACUSOLVER` wrappers for CUDA arrays.
+_gpu_geqrf!(A::StridedCuMatrix) = YACUSOLVER.geqrf!(A)
+_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
+function _gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix,
+                     τ::StridedCuVector, C::StridedCuVecOrMat)
+    return YACUSOLVER.unmqr!(side, trans, A, τ, C)
+end
+function _gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
+                     Vᴴ::StridedCuMatrix)
+    return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
+end
+function _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
+                       Vᴴ::StridedCuMatrix; kwargs...)
+    return YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
+end
+function _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
+                      Vᴴ::StridedCuMatrix; kwargs...)
+    return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
+end
+ end diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl deleted file mode 100644 index 636cf82e..00000000 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl +++ /dev/null @@ -1,69 +0,0 @@ -# CUSOLVER QR implementation -function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR) - check_input(qr_full!, A, QR) - Q, R = QR - _cusolver_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR) - check_input(qr_compact!, A, QR) - Q, R = QR - _cusolver_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::CUSOLVER_HouseholderQR) - check_input(qr_null!, A, N) - _cusolver_qr_null!(A, N; alg.kwargs...) - return N -end - -function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive=false, blocksize=1) - blocksize > 1 && - throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition")) - m, n = size(A) - minmn = min(m, n) - computeR = length(R) > 0 - inplaceQ = Q === A - if inplaceQ && (computeR || positive || m < n) - throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`")) - end - - A, τ = YACUSOLVER.geqrf!(A) - if inplaceQ - Q = YACUSOLVER.ungqr!(A, τ) - else - Q = YACUSOLVER.unmqr!('L', 'N', A, τ, one!(Q)) - end - # henceforth, τ is no longer needed and can be reused - - if positive # already fix Q even if we do not need R - # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA - τ .= sign_safe.(diagview(A)) - Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q - Qf .= Qf .* transpose(τ) - end - - if computeR - R̃ = uppertriangular!(view(A, axes(R)...)) - if positive - R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R - R̃f .= conj.(τ) .* R̃f - end - copyto!(R, R̃) - end - return Q, R 
-end - -function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix; - positive=false, blocksize=1) - blocksize > 1 && - throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition")) - m, n = size(A) - minmn = min(m, n) - fill!(N, zero(eltype(N))) - one!(view(N, (minmn + 1):m, 1:(m - minmn))) - A, τ = YACUSOLVER.geqrf!(A) - N = YACUSOLVER.unmqr!('L', 'N', A, τ, N) - return N -end diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl deleted file mode 100644 index f9976031..00000000 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl +++ /dev/null @@ -1,108 +0,0 @@ -const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration, - CUSOLVER_SVDPolar, - CUSOLVER_Jacobi} - -# CUSOLVER SVD implementation -function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ) - U, S, Vᴴ = USVᴴ - fill!(S, zero(eltype(S))) - m, n = size(A) - minmn = min(m, n) - if alg isa CUSOLVER_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments")) - YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa CUSOLVER_SVDPolar - YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) - elseif alg isa CUSOLVER_Jacobi - YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) 
- # elseif alg isa LAPACK_Bisection - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - # elseif alg isa LAPACK_Jacobi - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - diagview(S) .= view(S, 1:minmn, 1) - view(S, 2:minmn, 1) .= zero(eltype(S)) - # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:max(m, n) - if j <= minmn - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - elseif j <= m - u = view(U, :, j) - s = conj(sign(_argmaxabs(u))) - u .*= s - else - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(v))) - v .*= s - end - end - return USVᴴ -end - -function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ) - U, S, Vᴴ = USVᴴ - if alg isa CUSOLVER_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments")) - YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ) - elseif alg isa CUSOLVER_SVDPolar - YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) - elseif alg isa CUSOLVER_Jacobi - YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_DivideAndConquer - # isempty(alg.kwargs) || - # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) - # YALAPACK.gesdd!(A, S.diag, U, Vᴴ) - # elseif alg isa LAPACK_Bisection - # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:size(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - end - return USVᴴ -end -_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x))) -_largest(x, y) = abs(x) < abs(y) ? 
y : x - -function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm) - check_input(svd_vals!, A, S) - U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - if alg isa CUSOLVER_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments")) - YACUSOLVER.gesvd!(A, S, U, Vᴴ) - elseif alg isa CUSOLVER_SVDPolar - YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) - elseif alg isa CUSOLVER_Jacobi - YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_DivideAndConquer - # isempty(alg.kwargs) || - # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) - # YALAPACK.gesdd!(A, S, U, Vᴴ) - # elseif alg isa LAPACK_Bisection - # YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_Jacobi - # isempty(alg.kwargs) || - # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) - # YALAPACK.gesvj!(A, S, U, Vᴴ) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - return S -end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index b020d756..03111f3c 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -591,4 +591,4 @@ end # end # end -end \ No newline at end of file +end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index d3b45845..9e3685bb 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,7 +33,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LQViaTransposedQR, - CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi + CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, + ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi export truncrank, 
trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
 
 VERSION >= v"1.11.0-DEV.469" &&
diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl
index 1d30b4b1..5b45eff9 100644
--- a/src/implementations/qr.jl
+++ b/src/implementations/qr.jl
@@ -166,3 +166,97 @@ function _lapack_qr_null!(A::AbstractMatrix, N::AbstractMatrix;
     end
     return N
 end
+
+### GPU logic
+# placed here to avoid code duplication since much of the logic is replicable across
+# CUDA and AMDGPU
+###
+const GPU_HouseholderQR = Union{CUSOLVER_HouseholderQR,ROCSOLVER_HouseholderQR}
+
+# Validate `(A, QR)` for a full QR, then run the shared driver with `alg.kwargs`.
+function qr_full!(A::AbstractMatrix, QR, alg::GPU_HouseholderQR)
+    check_input(qr_full!, A, QR)
+    Q, R = QR
+    _gpu_qr!(A, Q, R; alg.kwargs...)
+    return Q, R
+end
+# Validate `(A, QR)` for a compact QR, then run the same driver as `qr_full!`.
+function qr_compact!(A::AbstractMatrix, QR, alg::GPU_HouseholderQR)
+    check_input(qr_compact!, A, QR)
+    Q, R = QR
+    _gpu_qr!(A, Q, R; alg.kwargs...)
+    return Q, R
+end
+# Validate `(A, N)`, then fill `N` through `_gpu_qr_null!` with `alg.kwargs`.
+function qr_null!(A::AbstractMatrix, N, alg::GPU_HouseholderQR)
+    check_input(qr_null!, A, N)
+    _gpu_qr_null!(A, N; alg.kwargs...)
+    return N
+end
+
+# Backend hooks used by the drivers below. These generic fallbacks only throw a
+# `MethodError`; a backend has to add methods for its own array types.
+_gpu_geqrf!(A::AbstractMatrix) = throw(MethodError(_gpu_geqrf!, (A,)))
+function _gpu_ungqr!(A::AbstractMatrix, τ::AbstractVector)
+    throw(MethodError(_gpu_ungqr!, (A, τ)))
+end
+function _gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::AbstractMatrix,
+                     τ::AbstractVector, C)
+    throw(MethodError(_gpu_unmqr!, (side, trans, A, τ, C)))
+end
+
+# Householder QR driver.
+#
+# Factorizes `A` in place with `_gpu_geqrf!`, forms `Q` (inside `A` when `Q === A`) and,
+# when `R` is non-empty, copies the upper-triangular part of the factorized `A` into
+# `R`. With `positive=true` the columns of `Q` are scaled by `sign_safe` of the diagonal
+# of the factorized `A` and the rows of `R` by its conjugate. `blocksize > 1` throws.
+function _gpu_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
+                  positive=false, blocksize=1)
+    blocksize > 1 &&
+        throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
+    m, n = size(A)
+    minmn = min(m, n)
+    computeR = length(R) > 0
+    inplaceQ = Q === A
+    if inplaceQ && (computeR || positive || m < n)
+        throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`"))
+    end
+
+    A, τ = _gpu_geqrf!(A)
+    Q = inplaceQ ? _gpu_ungqr!(A, τ) : _gpu_unmqr!('L', 'N', A, τ, one!(Q))
+    # henceforth, τ is no longer needed and can be reused
+
+    if positive # already fix Q even if we do not need R
+        # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA
+        τ .= sign_safe.(diagview(A))
+        Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q
+        Qf .= Qf .* transpose(τ)
+    end
+
+    if computeR
+        R̃ = uppertriangular!(view(A, axes(R)...))
+        if positive
+            R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R
+            R̃f .= conj.(τ) .* R̃f
+        end
+        copyto!(R, R̃)
+    end
+    return Q, R
+end
+
+# Null-space driver: zeroes `N`, puts an identity block in rows `minmn+1:m`, and applies
+# the reflectors returned by `_gpu_geqrf!(A)` to it via `_gpu_unmqr!`. `positive` is
+# accepted but not used; `blocksize > 1` throws.
+function _gpu_qr_null!(A::AbstractMatrix, N::AbstractMatrix;
+                       positive=false, blocksize=1)
+    blocksize > 1 &&
+        throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
+    m, n = size(A)
+    minmn = min(m, n)
+    fill!(N, zero(eltype(N)))
+    one!(view(N, (minmn + 1):m, 1:(m - minmn)))
+    A, τ = _gpu_geqrf!(A)
+    N = _gpu_unmqr!('L', 'N', A, τ, N)
+    return N
+end
diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl
index 031494f8..fa69b3d3 100644
--- 
a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -178,3 +178,111 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) return truncate!(svd_trunc!, USVᴴ′, alg.trunc) end + +### GPU logic +# placed here to avoid code duplication since much of the logic is replicable across +# CUDA and AMDGPU +### +const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration, + CUSOLVER_SVDPolar, + CUSOLVER_Jacobi} +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, + ROCSOLVER_Jacobi} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} + +const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} + +_gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) +_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) +_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) + +# GPU SVD implementation +function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) + check_input(svd_full!, A, USVᴴ) + U, S, Vᴴ = USVᴴ + fill!(S, zero(eltype(S))) + m, n = size(A) + minmn = min(m, n) + if alg isa GPU_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) + elseif alg isa GPU_SVDPolar + _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) + elseif alg isa GPU_Jacobi + _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) 
+ # elseif alg isa LAPACK_Bisection + # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) + # elseif alg isa LAPACK_Jacobi + # throw(ArgumentError("LAPACK_Jacobi is not supported for full SVD")) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + diagview(S) .= view(S, 1:minmn, 1) + view(S, 2:minmn, 1) .= zero(eltype(S)) + # TODO: make this controllable using a `gaugefix` keyword argument + for j in 1:max(m, n) + if j <= minmn + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + elseif j <= m + u = view(U, :, j) + s = conj(sign(_argmaxabs(u))) + u .*= s + else + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(v))) + v .*= s + end + end + return USVᴴ +end + +function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) + check_input(svd_compact!, A, USVᴴ) + U, S, Vᴴ = USVᴴ + if alg isa GPU_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, S.diag, U, Vᴴ) + elseif alg isa GPU_SVDPolar + _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) + elseif alg isa GPU_Jacobi + _gpu_gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + # TODO: make this controllable using a `gaugefix` keyword argument + for j in 1:size(U, 2) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return USVᴴ +end +_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x))) +_largest(x, y) = abs(x) < abs(y) ? y : x + +function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) + check_input(svd_vals!, A, S) + U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + if alg isa GPU_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, S, U, Vᴴ) + elseif alg isa GPU_SVDPolar + _gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
+ elseif alg isa GPU_Jacobi + _gpu_gesvdj!(A, S, U, Vᴴ; alg.kwargs...) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + return S +end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index bff490a9..9485c745 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -150,3 +150,32 @@ Algorithm type to denote the CUSOLVER driver for computing the singular value de a general matrix using the Jacobi algorithm. """ @algdef CUSOLVER_Jacobi + +# ========================= +# ROCSOLVER ALGORITHMS +# ========================= +""" + ROCSOLVER_HouseholderQR(; positive = false) + +Algorithm type to denote the standard ROCSOLVER algorithm for computing the QR decomposition of +a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that +the diagonal elements of `R` are non-negative. +""" +@algdef ROCSOLVER_HouseholderQR + +""" + ROCSOLVER_QRIteration() + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +QR Iteration algorithm. +""" +@algdef ROCSOLVER_QRIteration + +""" + ROCSOLVER_Jacobi() + +Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of +a general matrix using the Jacobi algorithm. +""" +@algdef ROCSOLVER_Jacobi diff --git a/test/amd/lq.jl b/test/amd/lq.jl new file mode 100644 index 00000000..5a5f5340 --- /dev/null +++ b/test/amd/lq.jl @@ -0,0 +1,117 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using Test +using TestExtras +using StableRNGs +using AMDGPU + +include("utilities.jl") + +@testset "lq_compact! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = ROCArray(randn(rng, T, m, n)) + L, Q = @constinferred lq_compact(A) + @test L isa ROCMatrix{T} && size(L) == (m, minmn) + @test Q isa ROCMatrix{T} && size(Q) == (minmn, n) + @test L * Q ≈ A + @test isapproxone(Q * Q') + Nᴴ = @constinferred lq_null(A) + @test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n) + @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) + @test isapproxone(Nᴴ * Nᴴ') + + Ac = similar(A) + L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ) + @test Nᴴ2 === Nᴴ + + # noL + noL = similar(A, 0, minmn) + Q2 = similar(Q) + lq_compact!(copy!(Ac, A), (noL, Q2)) + @test Q == Q2 + + # positive + lq_compact!(copy!(Ac, A), (L, Q); positive=true) + @test L * Q ≈ A + @test isapproxone(Q * Q') + @test all(>=(zero(real(T))), real(diagview(L))) + lq_compact!(copy!(Ac, A), (noL, Q2); positive=true) + @test Q == Q2 + + # explicit blocksize + lq_compact!(copy!(Ac, A), (L, Q); blocksize=1) + @test L * Q ≈ A + @test isapproxone(Q * Q') + lq_compact!(copy!(Ac, A), (noL, Q2); blocksize=1) + @test Q == Q2 + lq_null!(copy!(Ac, A), Nᴴ; blocksize=1) + @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) + @test isapproxone(Nᴴ * Nᴴ') + if m <= n + lq_compact!(copy!(Q2, A), (noL, Q2); blocksize=1) # in-place Q + @test Q ≈ Q2 + # these do not work because of the in-place Q + @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize=1) + @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); positive=true) + end + # no blocked ROCSOLVER implementation; use Ac and (L, Q) so only blocksize can trigger the error + @test_throws ArgumentError lq_compact!(copy!(Ac, A), (L, Q); blocksize=8) + end +end + +@testset "lq_full!
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = ROCArray(randn(rng, T, m, n)) + L, Q = lq_full(A) + @test L isa ROCMatrix{T} && size(L) == (m, n) + @test Q isa ROCMatrix{T} && size(Q) == (n, n) + @test L * Q ≈ A + @test isapproxone(Q * Q') + + Ac = similar(A) + L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + @test L * Q ≈ A + @test isapproxone(Q * Q') + + # noL + noL = similar(A, 0, n) + Q2 = similar(Q) + lq_full!(copy!(Ac, A), (noL, Q2)) + @test Q == Q2 + + # positive + lq_full!(copy!(Ac, A), (L, Q); positive=true) + @test L * Q ≈ A + @test isapproxone(Q * Q') + @test all(>=(zero(real(T))), real(diagview(L))) + lq_full!(copy!(Ac, A), (noL, Q2); positive=true) + @test Q == Q2 + + # explicit blocksize + lq_full!(copy!(Ac, A), (L, Q); blocksize=1) + @test L * Q ≈ A + @test isapproxone(Q * Q') + lq_full!(copy!(Ac, A), (noL, Q2); blocksize=1) + @test Q == Q2 + if n == m + lq_full!(copy!(Q2, A), (noL, Q2); blocksize=1) # in-place Q + @test Q ≈ Q2 + # these do not work because of the in-place Q + @test_throws ArgumentError lq_full!(copy!(Q2, A), (L, Q2); blocksize=1) + @test_throws ArgumentError lq_full!(copy!(Q2, A), (noL, Q2); positive=true) + end + # no blocked ROCSOLVER implementation + @test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize=8) + end +end diff --git a/test/amd/qr.jl b/test/amd/qr.jl new file mode 100644 index 00000000..694963b5 --- /dev/null +++ b/test/amd/qr.jl @@ -0,0 +1,122 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using Test +using TestExtras +using StableRNGs +using AMDGPU + +include("utilities.jl") + +@testset "qr_compact! and qr_null!
for T = $T" for T in (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = ROCArray(randn(rng, T, m, n)) + Q, R = @constinferred qr_compact(A) + @test Q isa ROCMatrix{T} && size(Q) == (m, minmn) + @test R isa ROCMatrix{T} && size(R) == (minmn, n) + @test Q * R ≈ A + N = @constinferred qr_null(A) + @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) + @test isapproxone(Q' * Q) + @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) + @test isapproxone(N' * N) + + Ac = similar(A) + Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + N2 = @constinferred qr_null!(copy!(Ac, A), N) + @test N2 === N + + # noR + Q2 = similar(Q) + noR = similar(A, minmn, 0) + qr_compact!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # positive + qr_compact!(copy!(Ac, A), (Q, R); positive=true) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + @test all(>=(zero(real(T))), real(diagview(R))) + qr_compact!(copy!(Ac, A), (Q2, noR); positive=true) + @test Q == Q2 + + # explicit blocksize + qr_compact!(copy!(Ac, A), (Q, R); blocksize=1) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=1) + @test Q == Q2 + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=1) + qr_null!(copy!(Ac, A), N; blocksize=1) + @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) + @test isapproxone(N' * N) + if n <= m + qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q + @test Q ≈ Q2 + # these do not work because of the in-place Q + @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R2)) + @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive=true) + end + # no blocked ROCSOLVER implementation + @test_throws ArgumentError qr_compact!(copy!(Ac, A), (Q2, R); blocksize=8) + end +end + +@testset "qr_full!
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 63 + for n in (37, m, 63) + minmn = min(m, n) + A = ROCArray(randn(rng, T, m, n)) + Q, R = qr_full(A) + @test Q isa ROCMatrix{T} && size(Q) == (m, m) + @test R isa ROCMatrix{T} && size(R) == (m, n) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + + Ac = similar(A) + Q2 = similar(Q) + noR = similar(A, m, 0) + Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + @test Q * R ≈ A + @test isapproxone(Q' * Q) + qr_full!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # noR + noR = similar(A, m, 0) + Q2 = similar(Q) + qr_full!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # positive + qr_full!(copy!(Ac, A), (Q, R); positive=true) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + @test all(>=(zero(real(T))), real(diagview(R))) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true) + @test Q == Q2 + + # explicit blocksize + qr_full!(copy!(Ac, A), (Q, R); blocksize=1) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + qr_full!(copy!(Ac, A), (Q2, noR); blocksize=1) + @test Q == Q2 + if n == m + qr_full!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q + @test Q ≈ Q2 + @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, R2)) + @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive=true) + end + # no blocked ROCSOLVER implementation + @test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize=8) + end +end diff --git a/test/amd/svd.jl b/test/amd/svd.jl new file mode 100644 index 00000000..8442b7cc --- /dev/null +++ b/test/amd/svd.jl @@ -0,0 +1,120 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using LinearAlgebra: Diagonal, isposdef +using Test +using TestExtras +using StableRNGs +using AMDGPU + +include("utilities.jl") + +@testset "svd_compact!
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63) + k = min(m, n) + algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) + @testset "algorithm $alg" for alg in algs + n > m && alg isa ROCSOLVER_QRIteration && continue # not supported + minmn = min(m, n) + A = ROCArray(randn(rng, T, m, n)) + + U, S, Vᴴ = svd_compact(A; alg) + @test U isa ROCMatrix{T} && size(U) == (m, minmn) + @test S isa Diagonal{real(T),<:ROCVector} && size(S) == (minmn, minmn) + @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(Vᴴ * Vᴴ') + @test isposdef(S) + + Ac = similar(A) + U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(Vᴴ * Vᴴ') + @test isposdef(S) + + Sd = svd_vals(A, alg) + @test ROCArray(diagview(S)) ≈ Sd + # ROCArray is necessary because norm of ROCArray view with non-unit step is broken + end + end +end + +@testset "svd_full! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63) + algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) + @testset "algorithm $alg" for alg in algs + n > m && alg isa ROCSOLVER_QRIteration && continue # not supported + A = ROCArray(randn(rng, T, m, n)) + U, S, Vᴴ = svd_full(A; alg) + @test U isa ROCMatrix{T} && size(U) == (m, m) + @test S isa ROCMatrix{real(T)} && size(S) == (m, n) + @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) + + Ac = similar(A) + U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) + + Sc = similar(A, real(T), min(m, n)) + Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) + @test Sc === Sc2 + @test ROCArray(diagview(S)) ≈ Sc + # ROCArray is necessary because norm of ROCArray view with non-unit step is broken + end + end +end + +# @testset "svd_trunc! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +# rng = StableRNG(123) +# m = 54 +# if LinearAlgebra.LAPACK.version() < v"3.12.0" +# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) +# else +# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), +# LAPACK_Jacobi()) +# end + +# @testset "size ($m, $n)" for n in (37, m, 63) +# @testset "algorithm $alg" for alg in algs +# n > m && alg isa LAPACK_Jacobi && continue # not supported +# A = randn(rng, T, m, n) +# S₀ = svd_vals(A) +# minmn = min(m, n) +# r = minmn - 2 + +# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# @test length(S1.diag) == r +# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + +# s = 1 + sqrt(eps(real(T))) +# trunc2 = trunctol(s * S₀[r + 1]) + +# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) +# @test length(S2.diag) == r +# @test U1 ≈ U2 +# @test S1 ≈ S2 +# @test V1ᴴ ≈ V2ᴴ +# end +# end +# end diff --git a/test/amd/utilities.jl b/test/amd/utilities.jl new file mode 100644 index 00000000..61518b55 --- /dev/null +++ b/test/amd/utilities.jl @@ -0,0 +1,3 @@ +function isapproxone(A) + return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) +end diff --git a/test/runtests.jl b/test/runtests.jl index 54923e85..d0414e02 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,3 +65,15 @@ if CUDA.functional() end end +using AMDGPU +if AMDGPU.functional() + @safetestset "AMDGPU QR" begin + include("amd/qr.jl") + end + @safetestset "AMDGPU LQ" begin + include("amd/lq.jl") + end + @safetestset "AMDGPU SVD" begin + include("amd/svd.jl") + end +end