Skip to content

Commit ac2b5aa

Browse files
Merge pull request #366 from avik-pal/ap/simplegmres
Add SimpleGMRES implementation
2 parents 60ae26a + 38e3401 commit ac2b5aa

12 files changed

+738
-8
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
style = "sciml"
2-
format_markdown = true
2+
format_markdown = true
3+
annotate_untyped_fields_with_any = false

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
Manifest.toml
66

77
*.swp
8+
.vscode
9+
wip

Project.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "2.6.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1011
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -28,32 +29,38 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2829
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2930

3031
[weakdeps]
32+
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
3133
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3234
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3335
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
36+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3437
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
3538
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
3639
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
3740
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3841

3942
[extensions]
43+
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4044
LinearSolveCUDAExt = "CUDA"
4145
LinearSolveHYPREExt = "HYPRE"
4246
LinearSolveIterativeSolversExt = "IterativeSolvers"
47+
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
4348
LinearSolveKrylovKitExt = "KrylovKit"
4449
LinearSolveMKLExt = "MKL_jll"
4550
LinearSolveMetalExt = "Metal"
4651
LinearSolvePardisoExt = "Pardiso"
4752

4853
[compat]
4954
ArrayInterface = "7.4.11"
55+
BlockDiagonals = "0.1"
5056
DocStringExtensions = "0.8, 0.9"
5157
EnumX = "1"
5258
FastLapackInterface = "1, 2"
5359
GPUArraysCore = "0.1"
5460
HYPRE = "1.4.0"
5561
IterativeSolvers = "0.9.2"
5662
KLU = "0.3.0, 0.4"
63+
KernelAbstractions = "0.9"
5764
Krylov = "0.9"
5865
KrylovKit = "0.5, 0.6"
5966
PrecompileTools = "1"
@@ -69,20 +76,22 @@ UnPack = "1"
6976
julia = "1.6"
7077

7178
[extras]
79+
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
7280
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7381
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
7482
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7583
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
7684
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
85+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
7786
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
7887
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
79-
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
8088
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
89+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
8190
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
8291
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8392
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8493
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8594
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8695

8796
[targets]
88-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
97+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]

docs/src/solvers/solvers.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ choice of Krylov method should be the one most constrained to the type of operat
7272
has, for example if positive definite then `Krylov_CG()`, but if no good properties then
7373
use `Krylov_GMRES()`.
7474

75+
!!! tip
76+
77+
If your materialized operator is a uniform block diagonal matrix, then you can use
78+
`SimpleGMRES(; blocksize = <known block size>)` to further improve performance.
79+
This often shows up in Neural Networks where the Jacobian wrt the Inputs (almost always)
80+
is a Uniform Block Diagonal matrix of Block Size = size of the input divided by the
81+
batch size.
82+
7583
## Full List of Methods
7684

7785
### RecursiveFactorization.jl
@@ -106,6 +114,7 @@ LinearSolve.jl contains some linear solvers built in for specialized cases.
106114
```@docs
107115
SimpleLUFactorization
108116
DiagonalFactorization
117+
SimpleGMRES
109118
```
110119

111120
### FastLapackInterface.jl

ext/LinearSolveBlockDiagonalsExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module LinearSolveBlockDiagonalsExt
2+
3+
using LinearSolve, BlockDiagonals
4+
5+
function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
6+
kwargs...)
7+
@assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
8+
# We need to perform this check even when `zeroinit == true`, since the type of the
9+
# cache is dependent on whether we are able to use the specialized dispatch.
10+
bsizes = blocksizes(A)
11+
usize = first(first(bsizes))
12+
uniform_blocks = true
13+
for bsize in bsizes
14+
if bsize[1] != usize || bsize[2] != usize
15+
uniform_blocks = false
16+
break
17+
end
18+
end
19+
# Can't help but perform dynamic dispatch here
20+
return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, args...;
21+
blocksize = usize, kwargs...)
22+
end
23+
24+
end
ext/LinearSolveKernelAbstractionsExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module LinearSolveKernelAbstractionsExt
2+
3+
using LinearSolve, KernelAbstractions
4+
5+
LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true
6+
7+
using GPUArraysCore
8+
9+
function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
10+
backend = get_backend(Hbis)
11+
kernel! = __fast_sym_givens_kernel!(backend)
12+
kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange=bsize)
13+
return c, s, R
14+
end
15+
16+
@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis))
17+
idx = @index(Global)
18+
@inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx])
19+
@inbounds c[idx] = _c
20+
@inbounds s[idx] = _s
21+
@inbounds R[idx] = _ρ
22+
end
23+
24+
end

src/LinearSolve.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,16 @@ PrecompileTools.@recompile_invalidations begin
2828
import InteractiveUtils
2929

3030
using LinearAlgebra: BlasInt, LU
31-
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
31+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
3232
@blasfunc, chkargsok
3333

3434
import GPUArraysCore
3535
import Preferences
36+
import ConcreteStructs: @concrete
3637

3738
# wrap
3839
import Krylov
39-
40+
4041
using SciMLBase
4142
end
4243

@@ -62,6 +63,11 @@ _isidentity_struct(λ::Number) = isone(λ)
6263
_isidentity_struct(A::UniformScaling) = isone(A.λ)
6364
_isidentity_struct(::SciMLOperators.IdentityOperator) = true
6465

66+
# Dispatch Friendly way to check if an extension is loaded
67+
__is_extension_loaded(::Val) = false
68+
69+
function _fast_sym_givens! end
70+
6571
# Code
6672

6773
const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
@@ -92,6 +98,7 @@ end
9298
include("common.jl")
9399
include("factorization.jl")
94100
include("simplelu.jl")
101+
include("simplegmres.jl")
95102
include("iterative_wrappers.jl")
96103
include("preconditioners.jl")
97104
include("solve_function.jl")
@@ -176,6 +183,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
176183
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
177184
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES
178185

186+
export SimpleGMRES
187+
179188
export HYPREAlgorithm
180189
export CudaOffloadFactorization
181190
export MKLPardisoFactorize, MKLPardisoIterate

src/iterative_wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
253253
Krylov.solve!(args...; M = M,
254254
kwargs...)
255255
elseif cache.cacheval isa Krylov.GmresSolver
256-
Krylov.solve!(args...; M = M, N = N,
256+
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
257257
kwargs...)
258258
elseif cache.cacheval isa Krylov.BicgstabSolver
259259
Krylov.solve!(args...; M = M, N = N,

0 commit comments

Comments
 (0)