Skip to content

Commit 8547c19

Browse files
committed
Working version for GMRES with blocksize
1 parent 786f9b1 commit 8547c19

File tree

3 files changed

+105
-16
lines changed

3 files changed

+105
-16
lines changed

ext/LinearSolveBlockDiagonalsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b,
1818
end
1919
# Can't help but perform dynamic dispatch here
2020
return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters,
21-
abstol, reltol, verbose, assumptions; zeroinit)
21+
abstol, reltol, verbose, assumptions; zeroinit, blocksize = usize)
2222
end
2323

2424
end

ext/LinearSolveNNlibExt.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,53 @@
11
module LinearSolveNNlibExt
22

3-
using LinearSolve, NNlib
3+
using LinearAlgebra, LinearSolve, NNlib
4+
import LinearSolve: SimpleGMRESCache, SimpleGMRES, OperatorAssumptions, _no_preconditioner,
5+
_init_cacheval, _norm2, LinearCache
6+
import UnPack: @unpack
7+
8+
function SciMLBase.solve!(cache::SimpleGMRESCache{true, T}, lincache::LinearCache) where {T}
9+
@unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol, blocksize = cache
10+
res_norm = β
11+
12+
# FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES
13+
# version
14+
for _ in 1:((maxiters ÷ M) + 1)
15+
for j in 1:M
16+
Qⱼ₊₁ = @view(Q[:, j + 1, :])
17+
mul!(vec(Qⱼ₊₁), A, vec(@view(Q[:, j, :]))) # Q(:,j+1) <- A Q(:, j)
18+
for i in 1:j
19+
H[i, j, :] .= vec(sum(@view(Q[:, i, :]) .* Qⱼ₊₁; dims = 1))
20+
Qⱼ₊₁ .-= H[i:i, j, :] .* @view(Q[:, i, :])
21+
end
22+
H[j + 1, j, :] .= vec(_norm2(Qⱼ₊₁, 1))
23+
Qⱼ₊₁ ./= H[j + 1, j:j, :]
24+
25+
# FIXME: Figure out a way to avoid the allocation
26+
# Using views doesn't work very well with LinearSolve
27+
y = similar(b, j, 1, size(H, 3))
28+
for bidx in 1:size(y, 3)
29+
y[:, :, bidx] .= @view(H[1:(j + 1), 1:j, bidx]) \ @view(βe₁[1:(j + 1), bidx])
30+
end
31+
32+
# Update the solution
33+
batched_mul!(reshape(x, blocksize, 1, :), @view(Q[:, 1:j, :]), y)
34+
mul!(r, A, x, T(-1), T(0))
35+
r .+= b
36+
res_norm = _norm2(reshape(r, blocksize, :), 1)
37+
38+
if maximum(res_norm) < abstol
39+
return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache;
40+
retcode = ReturnCode.Success)
41+
end
42+
end
43+
44+
# Restart
45+
Q[:, 1, :] = reshape(r, blocksize, :) ./ res_norm
46+
fill!(H, zero(T))
47+
end
48+
49+
return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache;
50+
retcode = ReturnCode.MaxIters)
51+
end
452

553
end

src/simplegmres.jl

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,15 @@ _no_preconditioner(::IdentityOperator) = true
5959
_no_preconditioner(::UniformScaling) = true
6060
_no_preconditioner(_) = false
6161

62-
function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...)
63-
return _init_cacheval(Val(false), alg, args...; kwargs...)
62+
_norm2(x) = norm(x, 2)
63+
_norm2(x, dims) = .√(sum(abs2, x; dims))
64+
65+
function init_cacheval(alg::SimpleGMRES{UDB}, args...; kwargs...) where {UDB}
66+
return _init_cacheval(Val(UDB), alg, args...; kwargs...)
6467
end
6568

66-
# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks
67-
# and use the specialized dispatch
6869
function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
69-
abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true)
70+
abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true, kwargs...)
7071
if zeroinit
7172
return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)),
7273
similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0),
@@ -75,6 +76,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter
7576

7677
@assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead."
7778
N = LinearAlgebra.checksquare(A)
79+
@assert N == length(b) "The size of `A` and `b` must match."
7880
T = eltype(u)
7981
M = min(maxiters, alg.restart)
8082
ϵ = eps(T)
@@ -87,7 +89,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter
8789

8890
mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # r0 <- -A u
8991
axpy!(T(1), b, @view(Q[:, 1])) # r0 <- r0 + b = b - A u
90-
β = norm(@view(Q[:, 1]), 2)
92+
β = _norm2(@view(Q[:, 1]))
9193
Q[:, 1] ./= β
9294

9395
x = u
@@ -100,6 +102,45 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter
100102
β, abstol)
101103
end
102104

105+
function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
106+
abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true,
107+
blocksize = alg.blocksize)
108+
if zeroinit
109+
return SimpleGMRESCache{true}(0, 0, maxiters, alg.blocksize, zero(eltype(u)),
110+
similar(b, 0, 0, 0), similar(b, 0, 0, 0), u, similar(b, 0), similar(b, 0, 0),
111+
A, b, similar(b, 0, 0), abstol)
112+
end
113+
114+
@assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead."
115+
N = LinearAlgebra.checksquare(A)
116+
@assert mod(N, blocksize)==0 "The blocksize must divide the size of the matrix."
117+
@assert N==length(b) "The size of `A` and `b` must match."
118+
T = eltype(u)
119+
M = min(maxiters, alg.restart)
120+
ϵ = eps(T)
121+
bsize = N ÷ blocksize
122+
123+
# Initialize the Cache
124+
## Use `b` since `A` might be an operator
125+
Q = similar(b, blocksize, M + 1, bsize)
126+
H = similar(b, M + 1, M, bsize)
127+
fill!(H, zero(T))
128+
129+
mul!(vec(@view(Q[:, 1, :])), A, u, T(-1), T(0)) # r0 <- -A u
130+
axpy!(T(1), b, vec(@view(Q[:, 1, :]))) # r0 <- r0 + b = b - A u
131+
β = _norm2(@view(Q[:, 1, :]), 1)
132+
Q[:, 1, :] ./= β
133+
134+
x = u
135+
r = similar(b)
136+
βe₁ = similar(b, M + 1, bsize)
137+
fill!(βe₁, 0)
138+
βe₁[1, :] .= vec(β) # Avoid the scalar indexing error
139+
140+
return SimpleGMRESCache{true}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b,
141+
β, abstol)
142+
end
143+
103144
default_alias_A(::SimpleGMRES, ::Any, ::Any) = false
104145
default_alias_b(::SimpleGMRES, ::Any, ::Any) = false
105146

@@ -111,25 +152,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...)
111152
cache.cacheval = solver
112153
cache.isfresh = false
113154
end
114-
return SciMLBase.solve!(cache.cacheval)
155+
return SciMLBase.solve!(cache.cacheval, cache)
115156
end
116157

117-
function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T}
158+
function SciMLBase.solve!(cache::SimpleGMRESCache{false, T},
159+
lincache::LinearCache) where {T}
118160
@unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache
119-
norm2 = Base.Fix2(norm, 2)
120161
res_norm = β
121162

122163
# FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES
123164
# version
124-
for _ in 1:(maxiters ÷ M + 1)
165+
for _ in 1:((maxiters ÷ M) + 1)
125166
for j in 1:M
126167
Qⱼ₊₁ = @view(Q[:, j + 1])
127168
mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j)
128169
for i in 1:j
129170
H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁)
130171
axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁)
131172
end
132-
H[j + 1, j] = norm2(Qⱼ₊₁)
173+
H[j + 1, j] = _norm2(Qⱼ₊₁)
133174
H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j])
134175

135176
# FIXME: Figure out a way to avoid the allocation
@@ -140,10 +181,10 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T}
140181
mul!(x, @view(Q[:, 1:j]), y)
141182
mul!(r, A, x, T(-1), T(0))
142183
axpy!(T(1), b, r)
143-
res_norm = norm2(r)
184+
res_norm = _norm2(r)
144185

145186
if res_norm < abstol
146-
return SciMLBase.build_linear_solution(nothing, x, r, nothing;
187+
return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache;
147188
retcode = ReturnCode.Success)
148189
end
149190
end
@@ -153,6 +194,6 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T}
153194
fill!(H, zero(T))
154195
end
155196

156-
return SciMLBase.build_linear_solution(nothing, x, r, nothing;
197+
return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache;
157198
retcode = ReturnCode.MaxIters)
158199
end

0 commit comments

Comments
 (0)