Skip to content

Commit 5303d5b

Browse files
Add AMDGPUOffloadFactorization algorithm support
This commit adds support for AMD GPU-accelerated linear solving through the new AMDGPUOffloadFactorization algorithm:

- Added AMDGPUOffloadFactorization struct in src/extension_algs.jl with proper error handling when AMDGPU.jl is not loaded
- Created LinearSolveAMDGPUExt extension in ext/LinearSolveAMDGPUExt.jl implementing GPU-offloaded LU factorization using AMDGPU.rocSOLVER
- Added AMDGPU as weak dependency and extension configuration in Project.toml
- Exported AMDGPUOffloadFactorization in src/LinearSolve.jl

The implementation follows the same pattern as CudaOffloadFactorization, using rocSOLVER for LU factorization and solve operations on AMD GPUs.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 93db65d commit 5303d5b

File tree

4 files changed

+59
-3
lines changed

4 files changed

+59
-3
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2828
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2929

3030
[weakdeps]
31+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3132
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3233
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33-
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
3434
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3535
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
3636
CUSOLVERRF = "a8cc9031-bad2-4722-94f5-40deabb4245c"
@@ -48,8 +48,10 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4848
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
4949
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5050
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
51+
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
5152

5253
[extensions]
54+
LinearSolveAMDGPUExt = "AMDGPU"
5355
LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
5456
LinearSolveBandedMatricesExt = "BandedMatrices"
5557
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
@@ -71,12 +73,12 @@ LinearSolveSparseArraysExt = "SparseArrays"
7173
LinearSolveSparspakExt = ["SparseArrays", "Sparspak"]
7274

7375
[compat]
76+
AMDGPU = "1"
7477
AllocCheck = "0.2"
7578
Aqua = "0.8"
7679
ArrayInterface = "7.7"
7780
BandedMatrices = "1.5"
7881
BlockDiagonals = "0.2"
79-
blis_jll = "0.9.0"
8082
CUDA = "5"
8183
CUDSS = "0.4"
8284
CUSOLVERRF = "0.2.6"
@@ -126,6 +128,7 @@ StaticArraysCore = "1.4.2"
126128
Test = "1"
127129
UnPack = "1"
128130
Zygote = "0.7"
131+
blis_jll = "0.9.0"
129132
julia = "1.10"
130133

131134
[extras]

ext/LinearSolveAMDGPUExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module LinearSolveAMDGPUExt
2+
3+
using AMDGPU
4+
using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadFactorization,
5+
init_cacheval, OperatorAssumptions
6+
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase
7+
8+
# Solve the linear system held in `cache` by offloading the work to an AMD GPU.
#
# When `cache.isfresh` is set, `cache.A` is copied into a `ROCArray`, factorized with
# `rocSOLVER.getrf!`, and the result is stored in `cache.cacheval` so later solves with
# the same matrix skip the factorization. The right-hand side `cache.b` is copied to the
# device on every call, solved with `rocSOLVER.getrs!`, and the result is copied back to
# the host, written into `cache.u`, and wrapped with `SciMLBase.build_linear_solution`.
# `kwargs` are accepted but not used here.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadFactorization;
        kwargs...)
    if cache.isfresh
        # Factorize a device copy so the host matrix `cache.A` is left untouched.
        cache.cacheval = AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(cache.A))
        cache.isfresh = false
    end

    # Only the first two entries of the cached factorization result are needed.
    factors, pivots = cache.cacheval
    rhs = AMDGPU.ROCArray(cache.b)

    # The return value is ignored: the solution is read back from `rhs`, so `getrs!`
    # is relied on to overwrite it in place.
    AMDGPU.rocSOLVER.getrs!('N', factors, pivots, rhs)

    sol = Array(rhs)
    cache.u .= sol
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end
25+
26+
# Build the initial `cacheval` for `AMDGPUOffloadFactorization` by copying `A` into a
# `ROCArray` and running `rocSOLVER.getrf!` on that copy. The value returned here has
# the same form as what `solve!` later stores after a refactorization. All arguments
# other than `A` are part of the common `init_cacheval` signature and are unused.
function LinearSolve.init_cacheval(alg::AMDGPUOffloadFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    device_A = AMDGPU.ROCArray(A)
    return AMDGPU.rocSOLVER.getrf!(device_A)
end
31+
32+
end

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
254254
export SimpleGMRES
255255

256256
export HYPREAlgorithm
257-
export CudaOffloadFactorization
257+
export CudaOffloadFactorization, AMDGPUOffloadFactorization
258258
export MKLPardisoFactorize, MKLPardisoIterate
259259
export PanuaPardisoFactorize, PanuaPardisoIterate
260260
export PardisoJL

src/extension_algs.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,27 @@ struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization
8282
end
8383
end
8484

85+
"""
86+
`AMDGPUOffloadFactorization()`
87+
88+
An offloading technique used to GPU-accelerate CPU-based computations on AMD GPUs.
89+
Requires a sufficiently large `A` to overcome the data transfer costs.
90+
91+
!!! note
92+
93+
Using this solver requires adding the package AMDGPU.jl, i.e. `using AMDGPU`
94+
"""
95+
struct AMDGPUOffloadFactorization <: LinearSolve.AbstractFactorization
    # Field-less algorithm tag. The inner constructor only verifies that the
    # `LinearSolveAMDGPUExt` package extension has been loaded (which happens once
    # AMDGPU.jl is loaded) and raises an error otherwise.
    function AMDGPUOffloadFactorization()
        if Base.get_extension(@__MODULE__, :LinearSolveAMDGPUExt) === nothing
            error("AMDGPUOffloadFactorization requires that AMDGPU is loaded, i.e. `using AMDGPU`")
        end
        return new{}()
    end
end
105+
85106
## RFLUFactorization
86107

87108
"""

0 commit comments

Comments
 (0)