Skip to content

Commit a9c5ed6

Browse files
committed
Create extension for AMD and generalize GPU wrappers
1 parent a8d9401 commit a9c5ed6

File tree

17 files changed

+1194
-186
lines changed

17 files changed

+1194
-186
lines changed

.buildkite/pipeline.yml

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ env:
22
SECRET_CODECOV_TOKEN: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw=="
33

44
steps:
5-
- label: "Julia v1"
5+
- label: "Julia v1 -- CUDA"
66
plugins:
77
- JuliaCI/julia#v1:
88
version: "1"
@@ -17,8 +17,7 @@ steps:
1717
if: build.message !~ /\[skip tests\]/
1818
timeout_in_minutes: 30
1919

20-
steps:
21-
- label: "Julia LTS"
20+
- label: "Julia LTS -- CUDA"
2221
plugins:
2322
- JuliaCI/julia#v1:
2423
version: "1.10" # "lts" isn't valid
@@ -32,3 +31,35 @@ steps:
3231
cuda: "*"
3332
if: build.message !~ /\[skip tests\]/
3433
timeout_in_minutes: 30
34+
35+
- label: "Julia v1 -- AMDGPU"
36+
plugins:
37+
- JuliaCI/julia#v1:
38+
version: "1"
39+
- JuliaCI/julia-test#v1: ~
40+
- JuliaCI/julia-coverage#v1:
41+
dirs:
42+
- src
43+
- ext
44+
agents:
45+
queue: "juliagpu"
46+
rocm: "*"
47+
rocmgpu: "*"
48+
if: build.message !~ /\[skip tests\]/
49+
timeout_in_minutes: 30
50+
51+
- label: "Julia LTS -- AMDGPU"
52+
plugins:
53+
- JuliaCI/julia#v1:
54+
version: "1.10" # "lts" isn't valid
55+
- JuliaCI/julia-test#v1: ~
56+
- JuliaCI/julia-coverage#v1:
57+
dirs:
58+
- src
59+
- ext
60+
agents:
61+
queue: "juliagpu"
62+
rocm: "*"
63+
rocmgpu: "*"
64+
if: build.message !~ /\[skip tests\]/
65+
timeout_in_minutes: 30

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ version = "0.2.5"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[weakdeps]
10+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1011
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1112
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1213

1314
[extensions]
1415
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
16+
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1517
MatrixAlgebraKitCUDAExt = "CUDA"
1618

1719
[compat]
20+
AMDGPU = "2"
1821
Aqua = "0.6, 0.7, 0.8"
1922
ChainRulesCore = "1"
2023
ChainRulesTestUtils = "1"
@@ -39,4 +42,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
3942
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4043

4144
[targets]
42-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"]
45+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
module MatrixAlgebraKitAMDGPUExt
2+
3+
using MatrixAlgebraKit
4+
using MatrixAlgebraKit: @algdef, Algorithm, check_input
5+
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
6+
using MatrixAlgebraKit: diagview, sign_safe
7+
using MatrixAlgebraKit: LQViaTransposedQR
8+
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
9+
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
10+
using AMDGPU
11+
using LinearAlgebra
12+
using LinearAlgebra: BlasFloat
13+
14+
include("yarocsolver.jl")
15+
16+
# Default QR algorithm for ROC (AMDGPU) strided matrix types.
# Dispatches on the matrix *type* `T <: StridedROCMatrix` (not on an instance) and
# returns a `ROCSOLVER_HouseholderQR` algorithm object; every keyword argument is
# forwarded unchanged to that constructor.
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
17+
return ROCSOLVER_HouseholderQR(; kwargs...)
18+
end
19+
# Default LQ algorithm for ROC (AMDGPU) strided matrix types.
# Builds a `ROCSOLVER_HouseholderQR` from the forwarded keyword arguments and wraps
# it in `LQViaTransposedQR`, which is the value returned.
# NOTE(review): the wrapper's name suggests the LQ factorization is obtained from a
# QR factorization of the transposed input -- confirm against its definition in
# MatrixAlgebraKit.
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
20+
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
21+
return LQViaTransposedQR(qr_alg)
22+
end
23+
# Default SVD algorithm for ROC (AMDGPU) strided matrix types.
# Dispatches on the matrix *type* `T <: StridedROCMatrix` and returns a
# `ROCSOLVER_QRIteration` algorithm object; every keyword argument is forwarded
# unchanged to that constructor.
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
24+
return ROCSOLVER_QRIteration(; kwargs...)
25+
end
26+
27+
# Backend hooks: add ROC-array methods to the `_gpu_*!` functions imported from
# MatrixAlgebraKit. Each method is a one-line forwarder that passes its arguments
# (and, where accepted, its keyword arguments) straight through to the function of
# the same name in `YArocSOLVER`.
# NOTE(review): `YArocSOLVER` is presumably the module defined by the included
# "yarocsolver.jl" -- confirm.
_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
28+
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
29+
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) = YArocSOLVER.unmqr!(side, trans, A, τ, C)
30+
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = YArocSOLVER.gesvd!(A, S, U, Vᴴ)
31+
# not yet supported
32+
#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
33+
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
34+
35+
end

0 commit comments

Comments
 (0)