Skip to content

Commit 72d38ee

Browse files
committed
Removes Mutable Iterator
1 parent 76fc7c0 commit 72d38ee

File tree

6 files changed

+103
-168
lines changed

6 files changed

+103
-168
lines changed

src/bisection.jl

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@ function Bisection(;exact_left=false, exact_right=false)
77
Bisection(exact_left, exact_right)
88
end
99

10-
mutable struct BisectionCache{uType}
11-
state::UInt8
10+
struct BisectionCache{uType}
11+
state::Int
1212
left::uType
1313
right::uType
1414
end
1515

1616
function alg_cache(alg::Bisection, left, right, p, ::Val{true})
17-
BisectionCache(UInt8(0), left, right)
17+
BisectionCache(0, left, right)
1818
end
1919

2020
function alg_cache(alg::Bisection, left, right, p, ::Val{false})
21-
BisectionCache(UInt8(0), left, right)
21+
BisectionCache(0, left, right)
2222
end
2323

24-
function perform_step!(solver, alg::Bisection, cache)
25-
@unpack f, p, left, right, fl, fr = solver
24+
function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache)
25+
@unpack f, p, left, right, fl, fr, cache = solver
2626

2727
if cache.state == 0
2828
fzero = zero(fl)
@@ -32,77 +32,83 @@ function perform_step!(solver, alg::Bisection, cache)
3232
mid = (left + right) / 2
3333

3434
if left == mid || right == mid
35-
solver.force_stop = true
36-
solver.retcode = :FloatingPointLimit
37-
return
35+
@set! solver.force_stop = true
36+
@set! solver.retcode = :FloatingPointLimit
37+
return solver
3838
end
3939

4040
fm = f(mid, p)
4141

4242
if iszero(fm)
4343
if alg.exact_left
44-
cache.state = 1
45-
cache.right = mid
46-
cache.left = mid
44+
@set! cache.state = 1
45+
@set! cache.right = mid
46+
@set! cache.left = mid
47+
@set! solver.cache = cache
4748
elseif alg.exact_right
48-
solver.left = prevfloat_tdir(mid, left, right)
49-
sync_residuals!(solver)
50-
cache.state = 2
51-
cache.left = mid
49+
@set! solver.left = prevfloat_tdir(mid, left, right)
50+
solver = sync_residuals!(solver)
51+
@set! cache.state = 2
52+
@set! cache.left = mid
53+
@set! solver.cache = cache
5254
else
53-
solver.left = prevfloat_tdir(mid, left, right)
54-
solver.right = nextfloat_tdir(mid, left, right)
55-
sync_residuals!(solver)
56-
solver.force_stop = true
57-
return
55+
@set! solver.left = prevfloat_tdir(mid, left, right)
56+
@set! solver.right = nextfloat_tdir(mid, left, right)
57+
solver = sync_residuals!(solver)
58+
@set! solver.force_stop = true
59+
return solver
5860
end
5961
else
6062
if sign(fm) == sign(fl)
61-
solver.left = mid
62-
solver.fl = fm
63+
@set! solver.left = mid
64+
@set! solver.fl = fm
6365
else
64-
solver.right = mid
65-
solver.fr = fm
66+
@set! solver.right = mid
67+
@set! solver.fr = fm
6668
end
6769
end
6870
elseif cache.state == 1
6971
mid = (left + cache.right) / 2
7072

7173
if cache.right == mid || left == mid
7274
if alg.exact_right
73-
cache.state = 2
74-
return
75+
@set! cache.state = 2
76+
@set! solver.cache = cache
77+
return solver
7578
else
76-
solver.right = nextfloat_tdir(mid, left, right)
77-
sync_residuals!(solver)
78-
solver.force_stop = true
79-
return
79+
@set! solver.right = nextfloat_tdir(mid, left, right)
80+
solver = sync_residuals!(solver)
81+
@set! solver.force_stop = true
82+
return solver
8083
end
8184
end
8285

8386
fm = f(mid, p)
8487

8588
if iszero(fm)
86-
cache.right = mid
89+
@set! cache.right = mid
90+
@set! solver.cache = cache
8791
else
88-
solver.left = mid
89-
solver.fl = fm
92+
@set! solver.left = mid
93+
@set! solver.fl = fm
9094
end
9195
else
9296
mid = (cache.left + right) / 2
9397

9498
if right == mid || cache.left == mid
95-
solver.force_stop = true
96-
return
99+
@set! solver.force_stop = true
100+
return solver
97101
end
98102

99103
fm = f(mid, p)
100104

101105
if iszero(fm)
102-
cache.left = mid
106+
@set! cache.left = mid
107+
@set! solver.cache = cache
103108
else
104-
solver.right = mid
105-
solver.fr = fm
109+
@set! solver.right = mid
110+
@set! solver.fr = fm
106111
end
107112
end
113+
solver
108114
end

src/falsi.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function alg_cache(alg::Falsi, left, right, p, ::Val{false})
99
nothing
1010
end
1111

12-
function perform_step!(solver, alg::Falsi, cache)
12+
function perform_step(solver, alg::Falsi, cache)
1313
@unpack f, p, left, right, fl, fr = solver
1414

1515
fzero = zero(fl)
@@ -19,27 +19,27 @@ function perform_step!(solver, alg::Falsi, cache)
1919
mid = (fr * left - fl * right) / (fr - fl)
2020

2121
if right == mid || right == mid
22-
solver.force_stop = true
23-
solver.retcode = :FloatingPointLimit
24-
return nothing
22+
@set! solver.force_stop = true
23+
@set! solver.retcode = :FloatingPointLimit
24+
return solver
2525
end
2626

2727
fm = f(mid, p)
2828

2929
if iszero(fm)
3030
# todo: phase 2 bisection similar to the raw method
31-
solver.force_stop = true
32-
solver.left = mid
33-
solver.fl = fm
34-
solver.retcode = :ExactSolutionAtLeft
31+
@set! solver.force_stop = true
32+
@set! solver.left = mid
33+
@set! solver.fl = fm
34+
@set! solver.retcode = :ExactSolutionAtLeft
3535
else
3636
if sign(fm) == sign(fl)
37-
solver.left = mid
38-
solver.fl = fm
37+
@set! solver.left = mid
38+
@set! solver.fl = fm
3939
else
40-
solver.right = mid
41-
solver.fr = fm
40+
@set! solver.right = mid
41+
@set! solver.fr = fm
4242
end
4343
end
44-
return nothing
44+
return solver
4545
end

src/raphson.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
1515
jac_config::JC
1616
end
1717

18-
struct NewtonRaphsonConstantCache{ufType}
19-
uf::ufType
20-
end
21-
2218
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
2319
uf = JacobianWrapper(f,p)
2420
linsolve = alg.linsolve(Val{:init}, f, u)
@@ -39,34 +35,24 @@ function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
3935
end
4036

4137
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{false})
42-
uf = JacobianWrapper(f,p)
43-
NewtonRaphsonConstantCache(uf)
44-
end
45-
46-
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonConstantCache)
47-
@unpack u, fu, f, p = solver
48-
J = calc_J(solver, cache)
49-
solver.u = u - J \ fu
50-
solver.fu = f(solver.u, p)
51-
if iszero(solver.fu) || solver.internalnorm(solver.fu) < solver.tol
52-
solver.force_stop = true
53-
end
38+
nothing
5439
end
5540

56-
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonCache)
57-
@unpack u, fu, f, p = solver
41+
function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{true})
42+
@unpack u, fu, f, p, cache = solver
5843
@unpack J, linsolve, du1 = cache
5944
calc_J!(J, solver, cache)
6045
# u = u - J \ fu
6146
linsolve(du1, J, fu, true)
6247
@. u = u - du1
6348
f(fu, u, p)
6449
if solver.internalnorm(solver.fu) < solver.tol
65-
solver.force_stop = true
50+
@set! solver.force_stop = true
6651
end
52+
return solver
6753
end
6854

69-
function perform_step(solver, alg::NewtonRaphson)
55+
function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{false})
7056
@unpack u, fu, f, p = solver
7157
J = calc_J(solver, ImmutableJacobianWrapper(f, p))
7258
@set! solver.u = u - J \ fu

src/solve.jl

Lines changed: 14 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ end
99
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
12-
# bracketing algorithms only solve scalar problems
13-
immutable = (eltype(prob.u0) <: Number),
1412
kwargs...
1513
) where {uType, iip}
1614

@@ -31,19 +29,13 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
3129
p = prob.p
3230
fl = f(left, p)
3331
fr = f(right, p)
34-
35-
if immutable
36-
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, :Default)
37-
else
38-
cache = alg_cache(alg, left, right, p, Val(iip))
39-
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default)
40-
end
32+
cache = alg_cache(alg, left, right,p, Val(iip))
33+
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, :Default, cache, iip)
4134
end
4235

4336
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
4437
alias_u0 = false,
4538
maxiters = 1000,
46-
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
4739
tol = 1e-6,
4840
internalnorm = Base.Fix2(DiffEqBase.ODE_DEFAULT_NORM, nothing),
4941
kwargs...
@@ -62,32 +54,14 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
6254
else
6355
fu = f(u, p)
6456
end
65-
66-
if immutable
67-
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, :Default, tol)
68-
else
69-
cache = alg_cache(alg, f, u, p, Val(iip))
70-
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol)
71-
end
72-
end
73-
74-
function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
75-
mic_check!(solver)
76-
while !solver.force_stop && solver.iter < solver.maxiters
77-
perform_step!(solver, solver.alg, solver.cache)
78-
solver.iter += 1
79-
end
80-
if solver.iter == solver.maxiters
81-
solver.retcode = :MaxitersExceeded
82-
end
83-
sol = get_solution(solver)
84-
return sol
57+
cache = alg_cache(alg, f, u, p, Val(iip))
58+
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, :Default, tol, cache, iip)
8559
end
8660

8761
function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
8862
solver = mic_check(solver)
8963
while !solver.force_stop && solver.iter < solver.maxiters
90-
solver = perform_step(solver, solver.alg)
64+
solver = perform_step(solver, solver.alg, Val(solver.iip))
9165
@set! solver.iter += 1
9266
end
9367
if solver.iter == solver.maxiters
@@ -103,21 +77,6 @@ end
10377
10478
Checks before running main solving iterations.
10579
"""
106-
function mic_check!(solver::BracketingSolver)
107-
@unpack f, fl, fr = solver
108-
flr = fl * fr
109-
fzero = zero(flr)
110-
(flr > fzero) && error("Non bracketing interval passed in bracketing method.")
111-
if fl == fzero
112-
solver.force_stop = true
113-
solver.retcode = :ExactSolutionAtLeft
114-
elseif fr == fzero
115-
solver.force_stop = true
116-
solver.retcode = :ExactSolutionAtRight
117-
end
118-
nothing
119-
end
120-
12180
function mic_check(solver::BracketingImmutableSolver)
12281
@unpack f, fl, fr = solver
12382
flr = fl * fr
@@ -133,10 +92,6 @@ function mic_check(solver::BracketingImmutableSolver)
13392
solver
13493
end
13594

136-
function mic_check!(solver::NewtonSolver)
137-
nothing
138-
end
139-
14095
function mic_check(solver::NewtonImmutableSolver)
14196
solver
14297
end
@@ -147,11 +102,11 @@ end
147102
148103
Form solution object from solver types
149104
"""
150-
function get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver})
105+
function get_solution(solver::BracketingImmutableSolver)
151106
return BracketingSolution(solver.left, solver.right, solver.retcode)
152107
end
153108

154-
function get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver})
109+
function get_solution(solver::NewtonImmutableSolver)
155110
return NewtonSolution(solver.u, solver.retcode)
156111
end
157112

@@ -160,16 +115,16 @@ end
160115
161116
Reinitialize solver to the original starting conditions
162117
"""
163-
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, true}) where {uType}
118+
function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType}
164119
@. solver.u = prob.u0
165-
solver.iter = 1
166-
solver.force_stop = false
120+
@set! solver.iter = 1
121+
@set! solver.force_stop = false
167122
return solver
168123
end
169124

170-
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, false}) where {uType}
171-
solver.u = prob.u0
172-
solver.iter = 1
173-
solver.force_stop = false
125+
function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType}
126+
@set! solver.u = prob.u0
127+
@set! solver.iter = 1
128+
@set! solver.force_stop = false
174129
return solver
175130
end

0 commit comments

Comments
 (0)