Skip to content

Commit 9400fb7

Browse files
Add AMDGPUOffloadFactorization algorithm support (#708)
* Add AMDGPUOffloadFactorization algorithm support This commit adds support for AMD GPU-accelerated linear solving through the new AMDGPUOffloadFactorization algorithm: - Added AMDGPUOffloadFactorization struct in src/extension_algs.jl with proper error handling when AMDGPU.jl is not loaded - Created LinearSolveAMDGPUExt extension in ext/LinearSolveAMDGPUExt.jl implementing GPU-offloaded LU factorization using AMDGPU.rocSOLVER - Added AMDGPU as weak dependency and extension configuration in Project.toml - Exported AMDGPUOffloadFactorization in src/LinearSolve.jl The implementation follows the same pattern as CudaOffloadFactorization, using rocSOLVER for LU factorization and solve operations on AMD GPUs. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Rename to AMDGPUOffloadLUFactorization and add AMDGPUOffloadQRFactorization - Renamed AMDGPUOffloadFactorization to AMDGPUOffloadLUFactorization for clarity - Added AMDGPUOffloadQRFactorization for QR-based solving - Updated extension to support both LU and QR factorizations - LU uses rocSOLVER.getrf\! and getrs\! - QR uses rocSOLVER.geqrf\!, ormqr\!, and rocBLAS.trsv\! 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]> --------- Co-authored-by: Claude <[email protected]>
1 parent 93db65d commit 9400fb7

File tree

4 files changed

+116
-3
lines changed

4 files changed

+116
-3
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2828
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2929

3030
[weakdeps]
31+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3132
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3233
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33-
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
3434
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3535
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
3636
CUSOLVERRF = "a8cc9031-bad2-4722-94f5-40deabb4245c"
@@ -48,8 +48,10 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4848
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
4949
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5050
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
51+
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
5152

5253
[extensions]
54+
LinearSolveAMDGPUExt = "AMDGPU"
5355
LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
5456
LinearSolveBandedMatricesExt = "BandedMatrices"
5557
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
@@ -71,12 +73,12 @@ LinearSolveSparseArraysExt = "SparseArrays"
7173
LinearSolveSparspakExt = ["SparseArrays", "Sparspak"]
7274

7375
[compat]
76+
AMDGPU = "1"
7477
AllocCheck = "0.2"
7578
Aqua = "0.8"
7679
ArrayInterface = "7.7"
7780
BandedMatrices = "1.5"
7881
BlockDiagonals = "0.2"
79-
blis_jll = "0.9.0"
8082
CUDA = "5"
8183
CUDSS = "0.4"
8284
CUSOLVERRF = "0.2.6"
@@ -126,6 +128,7 @@ StaticArraysCore = "1.4.2"
126128
Test = "1"
127129
UnPack = "1"
128130
Zygote = "0.7"
131+
blis_jll = "0.9.0"
129132
julia = "1.10"
130133

131134
[extras]

ext/LinearSolveAMDGPUExt.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module LinearSolveAMDGPUExt
2+
3+
using AMDGPU
4+
using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadLUFactorization,
5+
AMDGPUOffloadQRFactorization, init_cacheval, OperatorAssumptions
6+
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase
7+
8+
# LU Factorization
9+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
10+
kwargs...)
11+
if cache.isfresh
12+
fact = AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(cache.A))
13+
cache.cacheval = fact
14+
cache.isfresh = false
15+
end
16+
17+
A_gpu, ipiv = cache.cacheval
18+
b_gpu = AMDGPU.ROCArray(cache.b)
19+
20+
AMDGPU.rocSOLVER.getrs!('N', A_gpu, ipiv, b_gpu)
21+
22+
y = Array(b_gpu)
23+
cache.u .= y
24+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
25+
end
26+
27+
function LinearSolve.init_cacheval(alg::AMDGPUOffloadLUFactorization, A, b, u, Pl, Pr,
28+
maxiters::Int, abstol, reltol, verbose::Bool,
29+
assumptions::OperatorAssumptions)
30+
AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(A))
31+
end
32+
33+
# QR Factorization
34+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
35+
kwargs...)
36+
if cache.isfresh
37+
A_gpu = AMDGPU.ROCArray(cache.A)
38+
tau = AMDGPU.ROCVector{eltype(A_gpu)}(undef, min(size(A_gpu)...))
39+
AMDGPU.rocSOLVER.geqrf!(A_gpu, tau)
40+
cache.cacheval = (A_gpu, tau)
41+
cache.isfresh = false
42+
end
43+
44+
A_gpu, tau = cache.cacheval
45+
b_gpu = AMDGPU.ROCArray(cache.b)
46+
47+
# Apply Q^T to b
48+
AMDGPU.rocSOLVER.ormqr!('L', 'T', A_gpu, tau, b_gpu)
49+
50+
# Solve the upper triangular system
51+
m, n = size(A_gpu)
52+
AMDGPU.rocBLAS.trsv!('U', 'N', 'N', n, A_gpu, b_gpu)
53+
54+
y = Array(b_gpu[1:n])
55+
cache.u .= y
56+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
57+
end
58+
59+
function LinearSolve.init_cacheval(alg::AMDGPUOffloadQRFactorization, A, b, u, Pl, Pr,
60+
maxiters::Int, abstol, reltol, verbose::Bool,
61+
assumptions::OperatorAssumptions)
62+
A_gpu = AMDGPU.ROCArray(A)
63+
tau = AMDGPU.ROCVector{eltype(A_gpu)}(undef, min(size(A_gpu)...))
64+
AMDGPU.rocSOLVER.geqrf!(A_gpu, tau)
65+
(A_gpu, tau)
66+
end
67+
68+
end

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
254254
export SimpleGMRES
255255

256256
export HYPREAlgorithm
257-
export CudaOffloadFactorization
257+
export CudaOffloadFactorization, AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization
258258
export MKLPardisoFactorize, MKLPardisoIterate
259259
export PanuaPardisoFactorize, PanuaPardisoIterate
260260
export PardisoJL

src/extension_algs.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,48 @@ struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization
8282
end
8383
end
8484

85+
"""
86+
`AMDGPUOffloadLUFactorization()`
87+
88+
An offloading technique using LU factorization to GPU-accelerate CPU-based computations on AMD GPUs.
89+
Requires a sufficiently large `A` to overcome the data transfer costs.
90+
91+
!!! note
92+
93+
Using this solver requires adding the package AMDGPU.jl, i.e. `using AMDGPU`
94+
"""
95+
struct AMDGPUOffloadLUFactorization <: LinearSolve.AbstractFactorization
96+
function AMDGPUOffloadLUFactorization()
97+
ext = Base.get_extension(@__MODULE__, :LinearSolveAMDGPUExt)
98+
if ext === nothing
99+
error("AMDGPUOffloadLUFactorization requires that AMDGPU is loaded, i.e. `using AMDGPU`")
100+
else
101+
return new{}()
102+
end
103+
end
104+
end
105+
106+
"""
107+
`AMDGPUOffloadQRFactorization()`
108+
109+
An offloading technique using QR factorization to GPU-accelerate CPU-based computations on AMD GPUs.
110+
Requires a sufficiently large `A` to overcome the data transfer costs.
111+
112+
!!! note
113+
114+
Using this solver requires adding the package AMDGPU.jl, i.e. `using AMDGPU`
115+
"""
116+
struct AMDGPUOffloadQRFactorization <: LinearSolve.AbstractFactorization
117+
function AMDGPUOffloadQRFactorization()
118+
ext = Base.get_extension(@__MODULE__, :LinearSolveAMDGPUExt)
119+
if ext === nothing
120+
error("AMDGPUOffloadQRFactorization requires that AMDGPU is loaded, i.e. `using AMDGPU`")
121+
else
122+
return new{}()
123+
end
124+
end
125+
end
126+
85127
## RFLUFactorization
86128

87129
"""

0 commit comments

Comments
 (0)