Skip to content

Commit 76fc7c0

Browse files
committed
(refactor) General Cleaning and Tests
1 parent 23835fc commit 76fc7c0

File tree

6 files changed

+114
-85
lines changed

6 files changed

+114
-85
lines changed

src/jacobian.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ end
66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9-
mutable struct ImmutableJacobianWrapper{fType, pType}
9+
struct ImmutableJacobianWrapper{fType, pType}
1010
f::fType
1111
p::pType
1212
end
@@ -24,15 +24,13 @@ end
2424

2525
function calc_J(solver, uf::ImmutableJacobianWrapper)
2626
@unpack u, f, p, alg = solver
27-
@set! uf.f = f
28-
@set! uf.p = p
2927
J = jacobian(uf, u, solver)
3028
return J
3129
end
3230

3331
function jacobian(f, x, solver)
3432
if alg_autodiff(solver.alg)
35-
J = ForwardDiff.jacobian(f, Ref(x)[])
33+
J = ForwardDiff.jacobian(f, x)
3634
else
3735
J = FiniteDiff.finite_difference_derivative(f, x, solver.alg.diff_type, eltype(x))
3836
end

src/scalar.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
2525
fl, fr = f(left), f(right)
2626

2727
if iszero(fl)
28-
return fl
28+
return BracketingSolution(left, right, :ExactSolutionAtLeft)
2929
end
3030

3131
i = 1
3232
if !iszero(fr)
3333
while i < maxiters
3434
mid = (left + right) / 2
35-
(mid == left || mid == right) && return left
35+
(mid == left || mid == right) && return BracketingSolution(left, right, :FloatingPointLimit)
3636
fm = f(mid)
3737
if iszero(fm)
3838
right = mid
@@ -51,7 +51,7 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
5151

5252
while i < maxiters
5353
mid = (left + right) / 2
54-
(mid == left || mid == right) && return left
54+
(mid == left || mid == right) && return BracketingSolution(left, right, :FloatingPointLimit)
5555
fm = f(mid)
5656
if iszero(fm)
5757
right = mid
@@ -63,5 +63,5 @@ function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
6363
i += 1
6464
end
6565

66-
return left
66+
return BracketingSolution(left, right, :MaxitersExceeded)
6767
end

src/solve.jl

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
1010
alias_u0 = false,
1111
maxiters = 1000,
1212
# bracketing algorithms only solve scalar problems
13-
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
13+
immutable = (eltype(prob.u0) <: Number),
1414
kwargs...
1515
) where {uType, iip}
1616

1717
if !(prob.u0 isa Tuple)
1818
error("You need to pass a tuple of u0 in bracketing algorithms.")
1919
end
2020

21+
if eltype(prob.u0) isa AbstractArray
22+
error("Bracketing Algorithms work for scalar arguments only")
23+
end
24+
2125
if alias_u0
2226
left, right = prob.u0
2327
else
@@ -28,13 +32,11 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
2832
fl = f(left, p)
2933
fr = f(right, p)
3034

31-
cache = alg_cache(alg, left, right, p, Val(iip))
32-
33-
sol = build_solution(left, Val(iip))
3435
if immutable
35-
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
36+
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, :Default)
3637
else
37-
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
38+
cache = alg_cache(alg, left, right, p, Val(iip))
39+
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default)
3840
end
3941
end
4042

@@ -61,38 +63,32 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
6163
fu = f(u, p)
6264
end
6365

64-
65-
sol = build_newton_solution(u, Val(iip))
6666
if immutable
67-
return NewtonImmutableSolver(1, f, alg, u, fu, p, nothing, false, maxiters, internalnorm, :Default, tol, sol)
67+
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, :Default, tol)
6868
else
6969
cache = alg_cache(alg, f, u, p, Val(iip))
70-
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol, sol)
70+
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol)
7171
end
7272
end
7373

7474
function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
75-
# sync_residuals!(solver)
7675
mic_check!(solver)
7776
while !solver.force_stop && solver.iter < solver.maxiters
7877
perform_step!(solver, solver.alg, solver.cache)
7978
solver.iter += 1
80-
# sync_residuals!(solver)
8179
end
8280
if solver.iter == solver.maxiters
8381
solver.retcode = :MaxitersExceeded
8482
end
85-
solver = set_solution(solver)
86-
return solver.sol
83+
sol = get_solution(solver)
84+
return sol
8785
end
8886

8987
function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
90-
# sync_residuals!(solver)
9188
solver = mic_check(solver)
9289
while !solver.force_stop && solver.iter < solver.maxiters
9390
solver = perform_step(solver, solver.alg)
9491
@set! solver.iter += 1
95-
# sync_residuals!(solver)
9692
end
9793
if solver.iter == solver.maxiters
9894
@set! solver.retcode = :MaxitersExceeded
@@ -101,6 +97,12 @@ function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
10197
return sol
10298
end
10399

100+
"""
101+
mic_check(solver::AbstractImmutableNonlinearSolver)
102+
mic_check!(solver::AbstractNonlinearSolver)
103+
104+
Checks before running main solving iterations.
105+
"""
104106
function mic_check!(solver::BracketingSolver)
105107
@unpack f, fl, fr = solver
106108
flr = fl * fr
@@ -139,52 +141,35 @@ function mic_check(solver::NewtonImmutableSolver)
139141
solver
140142
end
141143

142-
function check_for_exact_solution!(solver::BracketingSolver)
143-
@unpack fl, fr = solver
144-
fzero = zero(fl)
145-
if fl == fzero
146-
solver.retcode = :ExactSolutionAtLeft
147-
return true
148-
elseif fr == fzero
149-
solver.retcode = :ExactSolutionAtRight
150-
return true
151-
end
152-
return false
153-
end
154-
155-
function set_solution(solver::BracketingSolver)
156-
sol = solver.sol
157-
@set! sol.left = solver.left
158-
@set! sol.right = solver.right
159-
@set! sol.retcode = solver.retcode
160-
@set! solver.sol = sol
161-
return solver
162-
end
144+
"""
145+
get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver})
146+
get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver})
163147
164-
function get_solution(solver::BracketingImmutableSolver)
165-
return (left = solver.left, right = solver.right, retcode = solver.retcode)
148+
Form solution object from solver types
149+
"""
150+
function get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver})
151+
return BracketingSolution(solver.left, solver.right, solver.retcode)
166152
end
167153

168-
function set_solution(solver::NewtonSolver)
169-
sol = solver.sol
170-
@set! sol.u = solver.u
171-
@set! sol.retcode = solver.retcode
172-
@set! solver.sol = sol
173-
return solver
154+
function get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver})
155+
return NewtonSolution(solver.u, solver.retcode)
174156
end
175157

176-
function get_solution(solver::NewtonImmutableSolver)
177-
return (u = solver.u, retcode = solver.retcode)
178-
end
158+
"""
159+
reinit!(solver, prob)
179160
161+
Reinitialize solver to the original starting conditions
162+
"""
180163
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, true}) where {uType}
181164
@. solver.u = prob.u0
182165
solver.iter = 1
183166
solver.force_stop = false
167+
return solver
184168
end
185169

186170
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, false}) where {uType}
187171
solver.u = prob.u0
188172
solver.iter = 1
189173
solver.force_stop = false
174+
return solver
190175
end

src/types.jl

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType, solType} <: AbstractNonlinearSolver
1+
mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractNonlinearSolver
22
iter::Int
33
f::fType
44
alg::algType
@@ -11,10 +11,9 @@ mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType
1111
force_stop::Bool
1212
maxiters::Int
1313
retcode::Symbol
14-
sol::solType
1514
end
1615

17-
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, solType} <: AbstractImmutableNonlinearSolver
16+
struct BracketingImmutableSolver{fType, algType, uType, resType, pType} <: AbstractImmutableNonlinearSolver
1817
iter::Int
1918
f::fType
2019
alg::algType
@@ -23,15 +22,13 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
2322
fl::resType
2423
fr::resType
2524
p::pType
26-
cache::cacheType
2725
force_stop::Bool
2826
maxiters::Int
2927
retcode::Symbol
30-
sol::solType
3128
end
3229

3330

34-
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType, solType} <: AbstractNonlinearSolver
31+
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType} <: AbstractNonlinearSolver
3532
iter::Int
3633
f::fType
3734
alg::algType
@@ -44,51 +41,35 @@ mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, IN
4441
internalnorm::INType
4542
retcode::Symbol
4643
tol::tolType
47-
sol::solType
4844
end
4945

50-
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType, solType} <: AbstractImmutableNonlinearSolver
46+
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType} <: AbstractImmutableNonlinearSolver
5147
iter::Int
5248
f::fType
5349
alg::algType
5450
u::uType
5551
fu::resType
5652
p::pType
57-
cache::cacheType
5853
force_stop::Bool
5954
maxiters::Int
6055
internalnorm::INType
6156
retcode::Symbol
6257
tol::tolType
63-
sol::solType
6458
end
6559

66-
function sync_residuals!(solver::BracketingSolver)
67-
solver.fl = solver.f(solver.left, solver.p)
68-
solver.fr = solver.f(solver.right, solver.p)
69-
nothing
70-
end
71-
72-
mutable struct BracketingSolution{uType}
60+
struct BracketingSolution{uType}
7361
left::uType
7462
right::uType
7563
retcode::Symbol
7664
end
7765

78-
function build_solution(u_prototype, ::Val{true})
79-
return BracketingSolution(similar(u_prototype), similar(u_prototype), :Default)
80-
end
81-
82-
function build_solution(u_prototype, ::Val{false})
83-
return BracketingSolution(zero(u_prototype), zero(u_prototype), :Default)
84-
end
85-
8666
struct NewtonSolution{uType}
8767
u::uType
8868
retcode::Symbol
8969
end
9070

91-
function build_newton_solution(u_prototype, ::Val{iip}) where iip
92-
return NewtonSolution(zero(u_prototype), :Default)
71+
function sync_residuals!(solver::BracketingSolver)
72+
solver.fl = solver.f(solver.left, solver.p)
73+
solver.fr = solver.f(solver.right, solver.p)
74+
nothing
9375
end
94-

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ function value_derivative(f::F, x::R) where {F,R}
2626
end
2727

2828
DiffEqBase.has_Wfact(f::Function) = false
29-
DiffEqBase.has_Wfact_t(f::Function) = false
29+
DiffEqBase.has_Wfact_t(f::Function) = false

test/runtests.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,68 @@ for p in 1.1:0.1:100.0
6868
@test g(p) sqrt(p)
6969
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
7070
end
71+
72+
# Error Checks
73+
74+
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
75+
probN = NonlinearProblem(f, u0)
76+
77+
@test solve(probN, NewtonRaphson()).u[end] sqrt(2.0)
78+
@test solve(probN, NewtonRaphson(); immutable = false).u[end] sqrt(2.0)
79+
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] sqrt(2.0)
80+
@test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] sqrt(2.0)
81+
82+
f, u0 = (u, p) -> u .* u .- 2.0, 1.0
83+
probN = NonlinearProblem(f, u0)
84+
85+
@test solve(probN, NewtonRaphson()).u sqrt(2.0)
86+
@test solve(probN, NewtonRaphson(); immutable = false).u sqrt(2.0)
87+
@test solve(probN, NewtonRaphson(;autodiff=false)).u sqrt(2.0)
88+
@test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u sqrt(2.0)
89+
90+
91+
# Bisection Tests
92+
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
93+
probB = NonlinearProblem(f, u0)
94+
95+
# this should call the fast scalar overload
96+
@test solve(probB, Bisection()).left sqrt(2.0)
97+
98+
# these should call the iterator version
99+
solver = init(probB, Bisection())
100+
@test solver isa NonlinearSolve.BracketingImmutableSolver
101+
# Question: Do we need BracketingImmutableSolver? We have fast scalar overload and
102+
# Bracketing solvers work only for scalars.
103+
104+
solver = init(probB, Bisection(); immutable = false)
105+
@test solver isa NonlinearSolve.BracketingSolver
106+
@test solve!(solver).left sqrt(2.0)
107+
108+
# Garuntee Tests for Bisection
109+
f = function (u, p)
110+
if u < 2.0
111+
return u - 2.0
112+
elseif u > 3.0
113+
return u - 3.0
114+
else
115+
return 0.0
116+
end
117+
end
118+
probB = NonlinearProblem(f, (0.0, 4.0))
119+
120+
solver = init(probB, Bisection(;exact_left = true); immutable = false)
121+
sol = solve!(solver)
122+
@test f(sol.left, nothing) < 0.0
123+
@test f(nextfloat(sol.left), nothing) >= 0.0
124+
125+
solver = init(probB, Bisection(;exact_right = true); immutable = false)
126+
sol = solve!(solver)
127+
@test f(sol.right, nothing) > 0.0
128+
@test f(prevfloat(sol.right), nothing) <= 0.0
129+
130+
solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false)
131+
sol = solve!(solver)
132+
@test f(sol.left, nothing) < 0.0
133+
@test f(nextfloat(sol.left), nothing) >= 0.0
134+
@test f(sol.right, nothing) > 0.0
135+
@test f(prevfloat(sol.right), nothing) <= 0.0

0 commit comments

Comments
 (0)