Skip to content

Commit 19312b1

Browse files
Avik Palavik-pal
authored andcommitted
Use KA.jl for faster sym_givens
1 parent bc74893 commit 19312b1

File tree

4 files changed

+50
-25
lines changed

4 files changed

+50
-25
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
3333
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3434
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3535
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
36+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3637
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
3738
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
3839
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
@@ -43,6 +44,7 @@ LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4344
LinearSolveCUDAExt = "CUDA"
4445
LinearSolveHYPREExt = "HYPRE"
4546
LinearSolveIterativeSolversExt = "IterativeSolvers"
47+
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
4648
LinearSolveKrylovKitExt = "KrylovKit"
4749
LinearSolveMKLExt = "MKL_jll"
4850
LinearSolveMetalExt = "Metal"
@@ -58,6 +60,7 @@ GPUArraysCore = "0.1"
5860
HYPRE = "1.4.0"
5961
IterativeSolvers = "0.9.2"
6062
KLU = "0.3.0, 0.4"
63+
KernelAbstractions = "0.9"
6164
Krylov = "0.9"
6265
KrylovKit = "0.5, 0.6"
6366
PrecompileTools = "1"
@@ -79,6 +82,7 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
7982
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
8083
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
8184
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
85+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
8286
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
8387
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
8488
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module LinearSolveKernelAbstractionsExt
2+
3+
using LinearSolve, KernelAbstractions
4+
5+
LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true
6+
7+
using GPUArraysCore
8+
9+
function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
10+
backend = get_backend(Hbis)
11+
kernel! = __fast_sym_givens_kernel!(backend)
12+
kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange=bsize)
13+
return c, s, R
14+
end
15+
16+
@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis))
17+
idx = @index(Global)
18+
@inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx])
19+
@inbounds c[idx] = _c
20+
@inbounds s[idx] = _s
21+
@inbounds R[idx] =
22+
end
23+
24+
end

src/LinearSolve.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ _isidentity_struct(λ::Number) = isone(λ)
5656
_isidentity_struct(A::UniformScaling) = isone(A.λ)
5757
_isidentity_struct(::SciMLOperators.IdentityOperator) = true
5858

59+
# Dispatch Friendly way to check if an extension is loaded
60+
__is_extension_loaded(::Val) = false
61+
62+
function _fast_sym_givens! end
63+
5964
# Code
6065

6166
const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)

src/simplegmres.jl

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ function _sym_givens(a::T, b::T) where {T <: AbstractFloat}
105105
return (c, s, ρ)
106106
end
107107

108+
function _sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
109+
if __is_extension_loaded(Val(:KernelAbstractions))
110+
return _fast_sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis)
111+
end
112+
__res = _sym_givens.(R[nr + inner_iter], Hbis)
113+
GPUArraysCore.@allowscalar foreach(1:bsize) do i
114+
c[inner_iter][i] = __res[i][1]
115+
s[inner_iter][i] = __res[i][2]
116+
R[nr + inner_iter][i] = __res[i][3]
117+
end
118+
return c, s, R
119+
end
120+
108121
_no_preconditioner(::Nothing) = true
109122
_no_preconditioner(::IdentityOperator) = true
110123
_no_preconditioner(::UniformScaling) = true
@@ -221,15 +234,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache)
221234
while !(solved || tired || breakdown)
222235
# Initialize workspace.
223236
nr = 0 # Number of coefficients stored in Rₖ.
224-
#= TODO: Check that not zeroing out doesn't lead to incorrect results.
225-
foreach(V) do v
226-
v .= zero(T) # Orthogonal basis of Kₖ(MAN, Mr₀).
227-
end
228-
s .= zero(T) # Givens sines used for the factorization QₖRₖ = Hₖ₊₁.ₖ.
229-
c .= zero(T) # Givens cosines used for the factorization QₖRₖ = Hₖ₊₁.ₖ.
230-
R .= zero(T) # Upper triangular matrix Rₖ.
231-
z .= zero(T) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖyₖ - βe₁‖₂.
232-
=#
237+
# TODO: Check that not zeroing out doesn't lead to incorrect results.
233238

234239
if restart
235240
xr .= zero(T) # xr === Δx when restart is set to true
@@ -517,13 +522,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache)
517522
# Compute and apply current Givens reflection Ωₖ.
518523
# [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ]
519524
# [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ]
520-
# FIXME: Write inplace kernel
521-
__res = _sym_givens.(R[nr + inner_iter], Hbis)
522-
foreach(1:bsize) do i
523-
c[inner_iter][i] = __res[i][1]
524-
s[inner_iter][i] = __res[i][2]
525-
R[nr + inner_iter][i] = __res[i][3]
526-
end
525+
_sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis)
527526

528527
# Update zₖ = (Qₖ)ᴴβe₁
529528
ζₖ₊₁ = conj.(s[inner_iter]) .* z[inner_iter]
@@ -567,15 +566,8 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache)
567566
pos = pos - j + 1 # position of rᵢ.ⱼ₋₁
568567
end
569568
# Rₖ can be singular if the system is inconsistent
570-
# FIXME: Write with broadcasting
571-
GPUArraysCore.@allowscalar foreach(1:bsize) do B
572-
if abs(R[pos][B]) btol
573-
y[i][B] = zero(T)
574-
inconsistent = true
575-
else
576-
y[i][B] /= R[pos][B]
577-
end
578-
end
569+
y[i] .= ifelse.(abs.(R[pos]) .≤ btol, zero(T), y[i] ./ R[pos]) # yᵢ ← yᵢ / rᵢᵢ
570+
inconsistent = any(abs.(R[pos]) .≤ btol)
579571
end
580572

581573
# Form xₖ = NVₖyₖ

0 commit comments

Comments
 (0)