module LinearSolveAMDGPUExt

using AMDGPU
using LinearSolve: LinearSolve, LinearCache, AMDGPUOffloadLUFactorization,
                   AMDGPUOffloadQRFactorization, init_cacheval, OperatorAssumptions
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase

# ---------------------------------------------------------------------------
# LU Factorization
# ---------------------------------------------------------------------------

# Solve the cached linear system with rocSOLVER's LU routines.
#
# When `cache.isfresh` is set, `cache.A` is converted to a `ROCArray`, factorized
# with `getrf!`, and the result is stored in `cache.cacheval`; the flag is then
# cleared so later calls reuse the stored factorization. The right-hand side
# `cache.b` is converted to a `ROCArray`, solved in place with `getrs!`, copied
# back to a host `Array`, written into `cache.u`, and returned wrapped by
# `SciMLBase.build_linear_solution`.
#
# `kwargs` are accepted for interface compatibility and are not used here.
function SciMLBase.solve!(cache::LinearSolve.LinearCache,
        alg::AMDGPUOffloadLUFactorization; kwargs...)
    if cache.isfresh
        fact = AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(cache.A))
        cache.cacheval = fact
        cache.isfresh = false
    end

    # Only the first two entries of the stored factorization are needed.
    A_gpu, ipiv = cache.cacheval
    b_gpu = AMDGPU.ROCArray(cache.b)

    # 'N': solve with the factorized matrix as-is (no transpose).
    AMDGPU.rocSOLVER.getrs!('N', A_gpu, ipiv, b_gpu)

    y = Array(b_gpu)
    cache.u .= y
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

# Build the initial `cacheval` for the LU algorithm: the value returned by
# `getrf!` applied to a `ROCArray` built from `A`. All other arguments are part
# of the `init_cacheval` interface and are not used here.
function LinearSolve.init_cacheval(alg::AMDGPUOffloadLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return AMDGPU.rocSOLVER.getrf!(AMDGPU.ROCArray(A))
end

# ---------------------------------------------------------------------------
# QR Factorization
# ---------------------------------------------------------------------------

# Convert `A` to a `ROCArray`, allocate a `tau` vector of length
# `min(size(A)...)` with the same element type, and run `geqrf!` on them.
# Returns the tuple `(A_gpu, tau)` used as the QR `cacheval`.
function _rocm_qr_factorize(A)
    A_gpu = AMDGPU.ROCArray(A)
    tau = AMDGPU.ROCVector{eltype(A_gpu)}(undef, min(size(A_gpu)...))
    AMDGPU.rocSOLVER.geqrf!(A_gpu, tau)
    return (A_gpu, tau)
end

# Solve the cached linear system with rocSOLVER's QR routines.
#
# When `cache.isfresh` is set, `cache.A` is factorized via `_rocm_qr_factorize`
# and the `(A_gpu, tau)` pair is stored in `cache.cacheval`; the flag is then
# cleared. The right-hand side is converted to a `ROCArray`, multiplied from the
# left by the (conjugate-)transposed Q factor with `ormqr!`, and the upper
# triangular system is solved with `trsv!`. The first `n` entries are copied
# back to the host, written into `cache.u`, and returned wrapped by
# `SciMLBase.build_linear_solution`.
#
# `kwargs` are accepted for interface compatibility and are not used here.
function SciMLBase.solve!(cache::LinearSolve.LinearCache,
        alg::AMDGPUOffloadQRFactorization; kwargs...)
    if cache.isfresh
        cache.cacheval = _rocm_qr_factorize(cache.A)
        cache.isfresh = false
    end

    A_gpu, tau = cache.cacheval
    b_gpu = AMDGPU.ROCArray(cache.b)

    # Apply Q' to b. Real element types use the plain transpose ('T'); complex
    # element types need the conjugate transpose ('C') for the result to be
    # Q^H * b.
    trans = eltype(A_gpu) <: Complex ? 'C' : 'T'
    AMDGPU.rocSOLVER.ormqr!('L', trans, A_gpu, tau, b_gpu)

    # Solve the upper triangular system R * x = Q' * b in place.
    # NOTE(review): confirm this argument list (including the explicit `n`)
    # against the installed `AMDGPU.rocBLAS.trsv!` method, and that passing the
    # full `A_gpu` / `b_gpu` is acceptable when `m > n`.
    m, n = size(A_gpu)
    AMDGPU.rocBLAS.trsv!('U', 'N', 'N', n, A_gpu, b_gpu)

    y = Array(b_gpu[1:n])
    cache.u .= y
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

# Build the initial `cacheval` for the QR algorithm: the `(A_gpu, tau)` pair
# produced by `_rocm_qr_factorize(A)`. All other arguments are part of the
# `init_cacheval` interface and are not used here.
function LinearSolve.init_cacheval(alg::AMDGPUOffloadQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return _rocm_qr_factorize(A)
end

end
0 commit comments