Skip to content
34 changes: 34 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
env:
SECRET_CODECOV_TOKEN: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw=="

steps:
- label: "Julia v1"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1: ~
- JuliaCI/julia-coverage#v1:
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 30

steps:
- label: "Julia LTS"
plugins:
- JuliaCI/julia#v1:
version: "1.10" # "lts" isn't valid
- JuliaCI/julia-test#v1: ~
- JuliaCI/julia-coverage#v1:
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 30
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitCUDAExt = "CUDA"

[compat]
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CUDA = "5"
JET = "0.9"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Expand All @@ -36,5 +39,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
"ChainRulesTestUtils", "StableRNGs", "Zygote"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"]
28 changes: 28 additions & 0 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module MatrixAlgebraKitCUDAExt

using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
using CUDA
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yacusolver.jl")
include("implementations/qr.jl")
include("implementations/svd.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
return CUSOLVER_QRIteration(; kwargs...)
end

end
69 changes: 69 additions & 0 deletions ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# CUSOLVER QR implementation
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR)
check_input(qr_full!, A, QR)
Q, R = QR
_cusolver_qr!(A, Q, R; alg.kwargs...)
return Q, R
end
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR)
check_input(qr_compact!, A, QR)
Q, R = QR
_cusolver_qr!(A, Q, R; alg.kwargs...)
return Q, R
end
function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::CUSOLVER_HouseholderQR)
check_input(qr_null!, A, N)
_cusolver_qr_null!(A, N; alg.kwargs...)
return N
end

function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive=false, blocksize=1)
blocksize > 1 &&
throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition"))
m, n = size(A)
minmn = min(m, n)
computeR = length(R) > 0
inplaceQ = Q === A
if inplaceQ && (computeR || positive || m < n)
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`"))
end

A, τ = YACUSOLVER.geqrf!(A)
if inplaceQ
Q = YACUSOLVER.ungqr!(A, τ)
else
Q = YACUSOLVER.unmqr!('L', 'N', A, τ, one!(Q))
end
# henceforth, τ is no longer needed and can be reused

if positive # already fix Q even if we do not need R
# TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA
τ .= sign_safe.(diagview(A))
Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q
Qf .= Qf .* transpose(τ)
end

if computeR
R̃ = uppertriangular!(view(A, axes(R)...))
if positive
R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R
R̃f .= conj.(τ) .* R̃f
end
copyto!(R, R̃)
end
return Q, R
end

function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix;
positive=false, blocksize=1)
blocksize > 1 &&
throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition"))
m, n = size(A)
minmn = min(m, n)
fill!(N, zero(eltype(N)))
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
A, τ = YACUSOLVER.geqrf!(A)
N = YACUSOLVER.unmqr!('L', 'N', A, τ, N)
return N
end
108 changes: 108 additions & 0 deletions ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
CUSOLVER_SVDPolar,
CUSOLVER_Jacobi}

# CUSOLVER SVD implementation
function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
check_input(svd_full!, A, USVᴴ)
U, S, Vᴴ = USVᴴ
fill!(S, zero(eltype(S)))
m, n = size(A)
minmn = min(m, n)
if alg isa CUSOLVER_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments"))
YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
elseif alg isa CUSOLVER_SVDPolar
YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
elseif alg isa CUSOLVER_Jacobi
YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
# elseif alg isa LAPACK_Bisection
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
# elseif alg isa LAPACK_Jacobi
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
else
throw(ArgumentError("Unsupported SVD algorithm"))
end
diagview(S) .= view(S, 1:minmn, 1)
view(S, 2:minmn, 1) .= zero(eltype(S))
# TODO: make this controllable using a `gaugefix` keyword argument
for j in 1:max(m, n)
if j <= minmn
u = view(U, :, j)
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(u)))
u .*= s
v .*= conj(s)
elseif j <= m
u = view(U, :, j)
s = conj(sign(_argmaxabs(u)))
u .*= s
else
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(v)))
v .*= s
end
end
return USVᴴ
end

function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
check_input(svd_compact!, A, USVᴴ)
U, S, Vᴴ = USVᴴ
if alg isa CUSOLVER_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ)
elseif alg isa CUSOLVER_SVDPolar
YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
elseif alg isa CUSOLVER_Jacobi
YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...)
# elseif alg isa LAPACK_DivideAndConquer
# isempty(alg.kwargs) ||
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
# YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
# elseif alg isa LAPACK_Bisection
# YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...)
else
throw(ArgumentError("Unsupported SVD algorithm"))
end
# TODO: make this controllable using a `gaugefix` keyword argument
for j in 1:size(U, 2)
u = view(U, :, j)
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(u)))
u .*= s
v .*= conj(s)
end
return USVᴴ
end
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))
_largest(x, y) = abs(x) < abs(y) ? y : x

function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
check_input(svd_vals!, A, S)
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
if alg isa CUSOLVER_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
elseif alg isa CUSOLVER_SVDPolar
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
elseif alg isa CUSOLVER_Jacobi
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...)
# elseif alg isa LAPACK_DivideAndConquer
# isempty(alg.kwargs) ||
# throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments"))
# YALAPACK.gesdd!(A, S, U, Vᴴ)
# elseif alg isa LAPACK_Bisection
# YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...)
# elseif alg isa LAPACK_Jacobi
# isempty(alg.kwargs) ||
# throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments"))
# YALAPACK.gesvj!(A, S, U, Vᴴ)
else
throw(ArgumentError("Unsupported SVD algorithm"))
end
return S
end
Loading
Loading