Skip to content

Commit f7ba927

Browse files
Rename to AMDGPUOffloadLUFactorization and add AMDGPUOffloadQRFactorization
- Renamed AMDGPUOffloadFactorization to AMDGPUOffloadLUFactorization for clarity
- Added AMDGPUOffloadQRFactorization for QR-based solving
- Updated extension to support both LU and QR factorizations
- LU uses rocSOLVER.getrf! and getrs!
- QR uses rocSOLVER.geqrf!, ormqr!, and rocBLAS.trsv!

🤖 Generated with Claude Code

Co-Authored-By: Claude <[email protected]>
1 parent 5303d5b commit f7ba927

File tree

3 files changed

+67
-10
lines changed

3 files changed

+67
-10
lines changed

ext/LinearSolveAMDGPUExt.jl

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module LinearSolveAMDGPUExt
22

33
using AMDGPU
4-
using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadFactorization,
5-
init_cacheval, OperatorAssumptions
4+
using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadLUFactorization,
5+
AMDGPUOffloadQRFactorization, init_cacheval, OperatorAssumptions
66
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase
77

8-
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadFactorization;
8+
# LU Factorization
9+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
910
kwargs...)
1011
if cache.isfresh
1112
fact = AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(cache.A))
@@ -23,10 +24,45 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadFact
2324
SciMLBase.build_linear_solution(alg, y, nothing, cache)
2425
end
2526

26-
function LinearSolve.init_cacheval(alg::AMDGPUOffloadFactorization, A, b, u, Pl, Pr,
27+
"""
    LinearSolve.init_cacheval(alg::AMDGPUOffloadLUFactorization, A, b, u, Pl, Pr,
        maxiters, abstol, reltol, verbose, assumptions)

Build the initial cache value for the LU offload algorithm: copy `A` into a `ROCArray`
and return whatever `rocSOLVER.getrf!` yields for that device copy. All other arguments
are accepted for interface compatibility and are not read here.
"""
function LinearSolve.init_cacheval(alg::AMDGPUOffloadLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # `getrf!` works in place, so it is given a fresh device copy rather than `A` itself.
    device_A = AMDGPU.ROCArray(A)
    return AMDGPU.rocSOLVER.getrf!(device_A)
end
3132

33+
# QR Factorization
34+
"""
    SciMLBase.solve!(cache::LinearSolve.LinearCache,
        alg::AMDGPUOffloadQRFactorization; kwargs...)

Solve the linear system held in `cache` on an AMD GPU through a QR factorization.

When `cache.isfresh` is set, `cache.A` is copied to the device, factorized in place with
`rocSOLVER.geqrf!`, and the pair `(A_gpu, tau)` is stored in `cache.cacheval`; the flag is
then cleared so later calls reuse the factorization. Every call copies `cache.b` to the
device, applies the adjoint of `Q` with `rocSOLVER.ormqr!`, solves the upper-triangular
system with `rocBLAS.trsv!`, and writes the leading `size(A, 2)` entries into `cache.u`.

Returns the result of `SciMLBase.build_linear_solution`. `kwargs` are accepted but unused.
"""
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
        kwargs...)
    if cache.isfresh
        A_gpu = AMDGPU.ROCArray(cache.A)
        tau = AMDGPU.ROCVector{eltype(A_gpu)}(undef, min(size(A_gpu)...))
        AMDGPU.rocSOLVER.geqrf!(A_gpu, tau)
        cache.cacheval = (A_gpu, tau)
        cache.isfresh = false
    end

    A_gpu, tau = cache.cacheval
    b_gpu = AMDGPU.ROCArray(cache.b)

    # Apply the adjoint of Q to b. For real element types that is the transpose ('T');
    # for complex element types it must be the conjugate transpose ('C'), since a plain
    # transpose of a unitary Q is not its inverse.
    trans = eltype(A_gpu) <: Complex ? 'C' : 'T'
    AMDGPU.rocSOLVER.ormqr!('L', trans, A_gpu, tau, b_gpu)

    # Solve the upper triangular system R * x = (Q' * b)[1:n]
    n = size(A_gpu, 2)
    AMDGPU.rocBLAS.trsv!('U', 'N', 'N', n, A_gpu, b_gpu)

    # Only the first n entries of the transformed right-hand side hold the solution.
    y = Array(b_gpu[1:n])
    cache.u .= y
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
58+
59+
"""
    LinearSolve.init_cacheval(alg::AMDGPUOffloadQRFactorization, A, b, u, Pl, Pr,
        maxiters, abstol, reltol, verbose, assumptions)

Build the initial cache value for the QR offload algorithm: copy `A` into a `ROCArray`,
allocate a device vector with `min(size(A)...)` entries, run `rocSOLVER.geqrf!` on the
pair, and return both as a tuple. All other arguments are accepted for interface
compatibility and are not read here.
"""
function LinearSolve.init_cacheval(alg::AMDGPUOffloadQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    device_A = AMDGPU.ROCArray(A)
    T = eltype(device_A)
    # One scalar per reflector: the smaller of the two matrix dimensions.
    nreflectors = minimum(size(device_A))
    reflectors = AMDGPU.ROCVector{T}(undef, nreflectors)
    AMDGPU.rocSOLVER.geqrf!(device_A, reflectors)
    return (device_A, reflectors)
end
67+
3268
end

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
254254
export SimpleGMRES
255255

256256
export HYPREAlgorithm
257-
export CudaOffloadFactorization, AMDGPUOffloadFactorization
257+
export CudaOffloadFactorization, AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization
258258
export MKLPardisoFactorize, MKLPardisoIterate
259259
export PanuaPardisoFactorize, PanuaPardisoIterate
260260
export PardisoJL

src/extension_algs.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,41 @@ struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization
8383
end
8484

8585
"""
86-
`AMDGPUOffloadFactorization()`
86+
`AMDGPUOffloadLUFactorization()`
8787
88-
An offloading technique used to GPU-accelerate CPU-based computations on AMD GPUs.
88+
An offloading technique using LU factorization to GPU-accelerate CPU-based computations on AMD GPUs.
8989
Requires a sufficiently large `A` to overcome the data transfer costs.
9090
9191
!!! note
9292
9393
Using this solver requires adding the package AMDGPU.jl, i.e. `using AMDGPU`
9494
"""
95-
struct AMDGPUOffloadFactorization <: LinearSolve.AbstractFactorization
96-
function AMDGPUOffloadFactorization()
95+
struct AMDGPUOffloadLUFactorization <: LinearSolve.AbstractFactorization
    # Inner constructor: construction is refused until the AMDGPU extension is available.
    function AMDGPUOffloadLUFactorization()
        # `Base.get_extension` gives `nothing` while LinearSolveAMDGPUExt is not loaded.
        amdgpu_ext = Base.get_extension(@__MODULE__, :LinearSolveAMDGPUExt)
        isnothing(amdgpu_ext) &&
            error("AMDGPUOffloadLUFactorization requires that AMDGPU is loaded, i.e. `using AMDGPU`")
        return new()
    end
end
105+
106+
"""
107+
`AMDGPUOffloadQRFactorization()`
108+
109+
An offloading technique using QR factorization to GPU-accelerate CPU-based computations on AMD GPUs.
110+
Requires a sufficiently large `A` to overcome the data transfer costs.
111+
112+
!!! note
113+
114+
Using this solver requires adding the package AMDGPU.jl, i.e. `using AMDGPU`
115+
"""
116+
struct AMDGPUOffloadQRFactorization <: LinearSolve.AbstractFactorization
117+
function AMDGPUOffloadQRFactorization()
118+
ext = Base.get_extension(@__MODULE__, :LinearSolveAMDGPUExt)
119+
if ext === nothing
120+
error("AMDGPUOffloadQRFactorization requires that AMDGPU is loaded, i.e. `using AMDGPU`")
100121
else
101122
return new{}()
102123
end

0 commit comments

Comments
 (0)