Add AMDGPUOffloadFactorization algorithm support #708
This commit adds support for AMD GPU-accelerated linear solving through the new AMDGPUOffloadFactorization algorithm:
- Added AMDGPUOffloadFactorization struct in src/extension_algs.jl with proper error handling when AMDGPU.jl is not loaded
- Created LinearSolveAMDGPUExt extension in ext/LinearSolveAMDGPUExt.jl implementing GPU-offloaded LU factorization using AMDGPU.rocSOLVER
- Added AMDGPU as weak dependency and extension configuration in Project.toml
- Exported AMDGPUOffloadFactorization in src/LinearSolve.jl
The implementation follows the same pattern as CudaOffloadFactorization, using rocSOLVER for LU factorization and solve operations on AMD GPUs.
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
…zation
- Renamed AMDGPUOffloadFactorization to AMDGPUOffloadLUFactorization for clarity
- Added AMDGPUOffloadQRFactorization for QR-based solving
- Updated extension to support both LU and QR factorizations
- LU uses rocSOLVER.getrf! and getrs!
- QR uses rocSOLVER.geqrf!, ormqr!, and rocBLAS.trsv!
🤖 Generated with Claude Code
Co-Authored-By: Claude <[email protected]>
Updated the PR with the following changes:
The implementation now provides two factorization options for AMD GPU offloading, allowing users to choose based on their numerical stability and performance requirements.
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase

# LU Factorization
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
[JuliaFormatter] reported by reviewdog 🐶
- function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
+ function SciMLBase.solve!(
+     cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
cache.cacheval = fact
cache.isfresh = false
end
[JuliaFormatter] reported by reviewdog 🐶
A_gpu, ipiv = cache.cacheval
b_gpu = AMDGPU.ROCArray(cache.b)
[JuliaFormatter] reported by reviewdog 🐶
b_gpu = AMDGPU.ROCArray(cache.b)

AMDGPU.rocSOLVER.getrs!('N', A_gpu, ipiv, b_gpu)
[JuliaFormatter] reported by reviewdog 🐶
end

# QR Factorization
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
[JuliaFormatter] reported by reviewdog 🐶
- function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
+ function SciMLBase.solve!(
+     cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
cache.cacheval = (A_gpu, tau)
cache.isfresh = false
end
[JuliaFormatter] reported by reviewdog 🐶
A_gpu, tau = cache.cacheval
b_gpu = AMDGPU.ROCArray(cache.b)
[JuliaFormatter] reported by reviewdog 🐶
# Apply Q^T to b
AMDGPU.rocSOLVER.ormqr!('L', 'T', A_gpu, tau, b_gpu)
[JuliaFormatter] reported by reviewdog 🐶
# Solve the upper triangular system
m, n = size(A_gpu)
AMDGPU.rocBLAS.trsv!('U', 'N', 'N', n, A_gpu, b_gpu)
[JuliaFormatter] reported by reviewdog 🐶
(A_gpu, tau)
end

end
[JuliaFormatter] reported by reviewdog 🐶
- end
+ end
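The two solve paths in the hunks above (factorize once, then reuse the cached factors for each right-hand side) can be sketched on the CPU with Julia's LinearAlgebra standard library, so no GPU is needed to follow the steps. This is only an analogy and not the extension's code: it assumes that `lu` and `\` stand in for `getrf!` and `getrs!`, and that `qr`, multiplication by `Q'`, and a triangular solve stand in for `geqrf!`, `ormqr!`, and `trsv!`. The matrix and right-hand side are made-up values.

```julia
using LinearAlgebra

# Illustrative 3x3 system (values are arbitrary; the matrix is nonsingular)
A = [4.0 1.0 2.0; 1.0 3.0 0.0; 2.0 0.0 5.0]
b = [1.0, 2.0, 3.0]

# LU path: factorize with partial pivoting, then solve using the stored factors
F_lu = lu(A)
x_lu = F_lu \ b

# QR path: factorize, apply Q^T to b, then solve the upper triangular system
F_qr = qr(A)
y = F_qr.Q' * b
x_qr = UpperTriangular(F_qr.R) \ y

# Both paths should reproduce b and agree with each other
@assert A * x_lu ≈ b
@assert x_lu ≈ x_qr
```

In the GPU version the factors live in `cache.cacheval` as device arrays, so only `b` has to be copied over on repeated solves.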
Summary
This PR adds support for AMD GPU-accelerated linear solving through the new AMDGPUOffloadFactorization algorithm:
• Added AMDGPUOffloadFactorization struct in src/extension_algs.jl with proper error handling when AMDGPU.jl is not loaded
• Created LinearSolveAMDGPUExt extension in ext/LinearSolveAMDGPUExt.jl implementing GPU-offloaded LU factorization using AMDGPU.rocSOLVER
• Added AMDGPU as weak dependency and extension configuration in Project.toml
• Exported AMDGPUOffloadFactorization in src/LinearSolve.jl
Implementation Details
The implementation follows the same pattern as CudaOffloadFactorization, using rocSOLVER.getrf! for LU factorization and rocSOLVER.getrs! for solve operations on AMD GPUs via ROCArrays. The algorithm provides GPU acceleration for sufficiently large matrices where the computation benefits outweigh the data transfer costs.
Test plan
🤖 Generated with Claude Code
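The remark in Implementation Details about computation outweighing data transfer can be made concrete with a rough back-of-the-envelope model. The sketch below assumes the standard leading-order flop count for dense LU (2n³/3) and counts only the bytes for copying `A` and `b` to the device and `x` back, with 8-byte elements. The helper names `lu_flops` and `transfer_bytes` are hypothetical, and nothing here is a measurement from this PR.

```julia
# Rough cost model; all names and constants are illustrative assumptions.

# Leading-order floating-point operation count of a dense n-by-n LU factorization
lu_flops(n) = 2n^3 / 3

# Bytes moved for one solve: A (n^2) and b (n) to the device, x (n) back
transfer_bytes(n; elsize = 8) = elsize * (n^2 + 2n)

for n in (100, 1_000, 10_000)
    intensity = lu_flops(n) / transfer_bytes(n)
    println("n = $n: about $(round(intensity; digits = 1)) flops per transferred byte")
end
```

The ratio grows roughly linearly in n, which is why offloading tends to pay off only once the matrix is large enough; the actual crossover depends on the hardware and has to be benchmarked.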