Skip to content

Commit 15a612a

Browse files
Merge pull request #2 from JuliaComputing/refactor
Refactor
2 parents 23835fc + b845c57 commit 15a612a

File tree

9 files changed

+214
-242
lines changed

9 files changed

+214
-242
lines changed

src/bisection.jl

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@ function Bisection(;exact_left=false, exact_right=false)
77
Bisection(exact_left, exact_right)
88
end
99

10-
mutable struct BisectionCache{uType}
11-
state::UInt8
10+
struct BisectionCache{uType}
11+
state::Int
1212
left::uType
1313
right::uType
1414
end
1515

1616
function alg_cache(alg::Bisection, left, right, p, ::Val{true})
17-
BisectionCache(UInt8(0), left, right)
17+
BisectionCache(0, left, right)
1818
end
1919

2020
function alg_cache(alg::Bisection, left, right, p, ::Val{false})
21-
BisectionCache(UInt8(0), left, right)
21+
BisectionCache(0, left, right)
2222
end
2323

24-
function perform_step!(solver, alg::Bisection, cache)
25-
@unpack f, p, left, right, fl, fr = solver
24+
function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache)
25+
@unpack f, p, left, right, fl, fr, cache = solver
2626

2727
if cache.state == 0
2828
fzero = zero(fl)
@@ -32,77 +32,83 @@ function perform_step!(solver, alg::Bisection, cache)
3232
mid = (left + right) / 2
3333

3434
if left == mid || right == mid
35-
solver.force_stop = true
36-
solver.retcode = :FloatingPointLimit
37-
return
35+
@set! solver.force_stop = true
36+
@set! solver.retcode = FLOATING_POINT_LIMIT
37+
return solver
3838
end
3939

4040
fm = f(mid, p)
4141

4242
if iszero(fm)
4343
if alg.exact_left
44-
cache.state = 1
45-
cache.right = mid
46-
cache.left = mid
44+
@set! cache.state = 1
45+
@set! cache.right = mid
46+
@set! cache.left = mid
47+
@set! solver.cache = cache
4748
elseif alg.exact_right
48-
solver.left = prevfloat_tdir(mid, left, right)
49-
sync_residuals!(solver)
50-
cache.state = 2
51-
cache.left = mid
49+
@set! solver.left = prevfloat_tdir(mid, left, right)
50+
solver = sync_residuals!(solver)
51+
@set! cache.state = 2
52+
@set! cache.left = mid
53+
@set! solver.cache = cache
5254
else
53-
solver.left = prevfloat_tdir(mid, left, right)
54-
solver.right = nextfloat_tdir(mid, left, right)
55-
sync_residuals!(solver)
56-
solver.force_stop = true
57-
return
55+
@set! solver.left = prevfloat_tdir(mid, left, right)
56+
@set! solver.right = nextfloat_tdir(mid, left, right)
57+
solver = sync_residuals!(solver)
58+
@set! solver.force_stop = true
59+
return solver
5860
end
5961
else
6062
if sign(fm) == sign(fl)
61-
solver.left = mid
62-
solver.fl = fm
63+
@set! solver.left = mid
64+
@set! solver.fl = fm
6365
else
64-
solver.right = mid
65-
solver.fr = fm
66+
@set! solver.right = mid
67+
@set! solver.fr = fm
6668
end
6769
end
6870
elseif cache.state == 1
6971
mid = (left + cache.right) / 2
7072

7173
if cache.right == mid || left == mid
7274
if alg.exact_right
73-
cache.state = 2
74-
return
75+
@set! cache.state = 2
76+
@set! solver.cache = cache
77+
return solver
7578
else
76-
solver.right = nextfloat_tdir(mid, left, right)
77-
sync_residuals!(solver)
78-
solver.force_stop = true
79-
return
79+
@set! solver.right = nextfloat_tdir(mid, left, right)
80+
solver = sync_residuals!(solver)
81+
@set! solver.force_stop = true
82+
return solver
8083
end
8184
end
8285

8386
fm = f(mid, p)
8487

8588
if iszero(fm)
86-
cache.right = mid
89+
@set! cache.right = mid
90+
@set! solver.cache = cache
8791
else
88-
solver.left = mid
89-
solver.fl = fm
92+
@set! solver.left = mid
93+
@set! solver.fl = fm
9094
end
9195
else
9296
mid = (cache.left + right) / 2
9397

9498
if right == mid || cache.left == mid
95-
solver.force_stop = true
96-
return
99+
@set! solver.force_stop = true
100+
return solver
97101
end
98102

99103
fm = f(mid, p)
100104

101105
if iszero(fm)
102-
cache.left = mid
106+
@set! cache.left = mid
107+
@set! solver.cache = cache
103108
else
104-
solver.right = mid
105-
solver.fr = fm
109+
@set! solver.right = mid
110+
@set! solver.fr = fm
106111
end
107112
end
113+
solver
108114
end

src/falsi.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function alg_cache(alg::Falsi, left, right, p, ::Val{false})
99
nothing
1010
end
1111

12-
function perform_step!(solver, alg::Falsi, cache)
12+
function perform_step(solver, alg::Falsi, cache)
1313
@unpack f, p, left, right, fl, fr = solver
1414

1515
fzero = zero(fl)
@@ -19,27 +19,27 @@ function perform_step!(solver, alg::Falsi, cache)
1919
mid = (fr * left - fl * right) / (fr - fl)
2020

2121
if right == mid || right == mid
22-
solver.force_stop = true
23-
solver.retcode = :FloatingPointLimit
24-
return nothing
22+
@set! solver.force_stop = true
23+
@set! solver.retcode = FLOATING_POINT_LIMIT
24+
return solver
2525
end
2626

2727
fm = f(mid, p)
2828

2929
if iszero(fm)
3030
# todo: phase 2 bisection similar to the raw method
31-
solver.force_stop = true
32-
solver.left = mid
33-
solver.fl = fm
34-
solver.retcode = :ExactSolutionAtLeft
31+
@set! solver.force_stop = true
32+
@set! solver.left = mid
33+
@set! solver.fl = fm
34+
@set! solver.retcode = EXACT_SOLUTION_LEFT
3535
else
3636
if sign(fm) == sign(fl)
37-
solver.left = mid
38-
solver.fl = fm
37+
@set! solver.left = mid
38+
@set! solver.fl = fm
3939
else
40-
solver.right = mid
41-
solver.fr = fm
40+
@set! solver.right = mid
41+
@set! solver.fr = fm
4242
end
4343
end
44-
return nothing
44+
return solver
4545
end

src/jacobian.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ end
66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9-
mutable struct ImmutableJacobianWrapper{fType, pType}
9+
struct ImmutableJacobianWrapper{fType, pType}
1010
f::fType
1111
p::pType
1212
end
@@ -24,15 +24,13 @@ end
2424

2525
function calc_J(solver, uf::ImmutableJacobianWrapper)
2626
@unpack u, f, p, alg = solver
27-
@set! uf.f = f
28-
@set! uf.p = p
2927
J = jacobian(uf, u, solver)
3028
return J
3129
end
3230

3331
function jacobian(f, x, solver)
3432
if alg_autodiff(solver.alg)
35-
J = ForwardDiff.jacobian(f, Ref(x)[])
33+
J = ForwardDiff.jacobian(f, x)
3634
else
3735
J = FiniteDiff.finite_difference_derivative(f, x, solver.alg.diff_type, eltype(x))
3836
end

src/raphson.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
1515
jac_config::JC
1616
end
1717

18-
struct NewtonRaphsonConstantCache{ufType}
19-
uf::ufType
20-
end
21-
2218
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
2319
uf = JacobianWrapper(f,p)
2420
linsolve = alg.linsolve(Val{:init}, f, u)
@@ -39,34 +35,24 @@ function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
3935
end
4036

4137
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{false})
42-
uf = JacobianWrapper(f,p)
43-
NewtonRaphsonConstantCache(uf)
44-
end
45-
46-
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonConstantCache)
47-
@unpack u, fu, f, p = solver
48-
J = calc_J(solver, cache)
49-
solver.u = u - J \ fu
50-
solver.fu = f(solver.u, p)
51-
if iszero(solver.fu) || solver.internalnorm(solver.fu) < solver.tol
52-
solver.force_stop = true
53-
end
38+
nothing
5439
end
5540

56-
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonCache)
57-
@unpack u, fu, f, p = solver
41+
function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{true})
42+
@unpack u, fu, f, p, cache = solver
5843
@unpack J, linsolve, du1 = cache
5944
calc_J!(J, solver, cache)
6045
# u = u - J \ fu
6146
linsolve(du1, J, fu, true)
6247
@. u = u - du1
6348
f(fu, u, p)
6449
if solver.internalnorm(solver.fu) < solver.tol
65-
solver.force_stop = true
50+
@set! solver.force_stop = true
6651
end
52+
return solver
6753
end
6854

69-
function perform_step(solver, alg::NewtonRaphson)
55+
function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{false})
7056
@unpack u, fu, f, p = solver
7157
J = calc_J(solver, ImmutableJacobianWrapper(f, p))
7258
@set! solver.u = u - J \ fu

src/scalar.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, arg
88
xo = oftype(x, Inf)
99
for i in 1:maxiters
1010
fx, dfx = value_derivative(f, x)
11-
iszero(fx) && return NewtonSolution(x, :Default)
11+
iszero(fx) && return NewtonSolution(x, DEFAULT)
1212
Δx = dfx \ fx
1313
x -= Δx
1414
if isapprox(x, xo, atol=atol, rtol=rtol)
15-
return NewtonSolution(x, :Default)
15+
return NewtonSolution(x, DEFAULT)
1616
end
1717
xo = x
1818
end
19-
return NewtonSolution(x, :MaxitersExceeded)
19+
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

2222
function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
@@ -25,14 +25,14 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
2525
fl, fr = f(left), f(right)
2626

2727
if iszero(fl)
28-
return fl
28+
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
2929
end
3030

3131
i = 1
3232
if !iszero(fr)
3333
while i < maxiters
3434
mid = (left + right) / 2
35-
(mid == left || mid == right) && return left
35+
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
3636
fm = f(mid)
3737
if iszero(fm)
3838
right = mid
@@ -51,7 +51,7 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
5151

5252
while i < maxiters
5353
mid = (left + right) / 2
54-
(mid == left || mid == right) && return left
54+
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
5555
fm = f(mid)
5656
if iszero(fm)
5757
right = mid
@@ -63,5 +63,5 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
6363
i += 1
6464
end
6565

66-
return left
66+
return BracketingSolution(left, right, MAXITERS_EXCEED)
6767
end

0 commit comments

Comments
 (0)