Skip to content

Commit 3d82d71

Browse files
committed
initial commit
1 parent e91e4ac commit 3d82d71

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
89
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
910
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1011
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SparseArrays
88
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
99
using Setfield
1010
using UnPack
11+
using FastBroadcast
1112

1213
# wrap
1314
import Krylov

src/common.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ function set_cacheval(cache::LinearCache, alg_cache)
4444
return cache
4545
end
4646

47+
function set_prec(cache, Pl, Pr)
48+
@set! cache.Pl = Pl
49+
@set! cache.Pr = Pr
50+
return cache
51+
end
52+
4753
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
4854

4955
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)

src/wrappers.jl

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11

2-
#TODO: composed preconditioners, preconditioner setter for cache,
3-
# detailed tests for wrappers
4-
52
## Preconditioners
63

74
struct ScaleVector{T}
85
s::T
96
isleft::Bool
107
end
118

12-
function LinearAlgebra.ldiv!(v::ScaleVector, x)
9+
function LinearAlgebra.ldiv!(P::ScaleVector, x)
10+
P.s == one(eltype(P.s)) && return x
11+
12+
if P.isleft
13+
@. x = x * P.s # @.. doesnt speed up scalar computation
14+
else
15+
@. x = x / P.s
16+
end
1317
end
1418

15-
function LinearAlgebra.ldiv!(y, v::ScaleVector, x)
19+
function LinearAlgebra.ldiv!(y, P::ScaleVector, x)
20+
P.s == one(eltype(P.s)) && return y = x
21+
22+
if P.isleft
23+
@. y = x / P.s
24+
else
25+
@. y = x * P.s
26+
end
1627
end
1728

1829
struct ComposePreconditioner{Ti,To}
@@ -21,12 +32,19 @@ struct ComposePreconditioner{Ti,To}
2132
isleft::Bool
2233
end
2334

24-
function LinearAlgebra.ldiv!(v::ComposePreconditioner, x)
25-
@unpack inner, outer, isleft = v
35+
function LinearAlgebra.ldiv!(P::ComposePreconditioner, x)
36+
@unpack inner, outer, isleft = P
37+
if isleft
38+
ldiv!(outer, x)
39+
ldiv!(inner, x)
40+
else
41+
ldiv!(inner, x)
42+
ldiv!(outer, x)
43+
end
2644
end
2745

28-
function LinearAlgebra.ldiv!(y, v::ComposePreconditioner, x)
29-
@unpack inner, outer, isleft = v
46+
function LinearAlgebra.ldiv!(y, P::ComposePreconditioner, x)
47+
@unpack inner, outer, isleft = P
3048
end
3149

3250
## Krylov.jl
@@ -132,6 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
132150
cache = set_cacheval(cache, solver)
133151
end
134152

153+
# Pl = ComposePreconditioner(alg.Pl, cache.Pl, true)
154+
# Pr = ComposePreconditioner(alg.Pr, cache.Pr, false)
155+
135156
atol = cache.abstol
136157
rtol = cache.reltol
137158
itmax = cache.maxiters
@@ -204,8 +225,11 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
204225
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
205226
@unpack A, b, u = cache
206227

207-
Pl = (alg.Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pl
208-
Pr = (alg.Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pr
228+
Pl = alg.Pl
229+
Pr = alg.Pr
230+
231+
# Pl = ComposePreconditioner(alg.Pl, cache.Pl, true)
232+
# Pr = ComposePreconditioner(alg.Pr, cache.Pr, false)
209233

210234
abstol = cache.abstol
211235
reltol = cache.reltol

0 commit comments

Comments
 (0)