Skip to content

Commit 4cfce94

Browse files
committed
DiffEq Style
1 parent ee8fb32 commit 4cfce94

File tree

6 files changed

+233
-2
lines changed

6 files changed

+233
-2
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@ name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["Kanav Gupta <[email protected]>"]
44
version = "0.1.0"
5+
6+
[deps]
7+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
9+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/NonlinearSolve.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
11
module NonlinearSolve
22

3+
using Reexport
4+
@reexport using DiffEqBase
5+
using UnPack: @unpack
6+
7+
abstract type AbstractNonlinearSolveAlgorithm end
8+
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
9+
10+
include("types.jl")
11+
include("solve.jl")
312
include("utils.jl")
413
include("bisection.jl")
514
include("falsi.jl")
615

16+
# raw methods
717
export bisection, falsi
18+
19+
# DiffEq styled algorithms
20+
export Bisection, Falsi
821
end # module

src/bisection.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
struct Bisection <: AbstractBracketingAlgorithm
2+
end
3+
4+
function alg_cache(alg::Bisection, left, right, p, ::Val{true})
5+
nothing
6+
end
7+
8+
function alg_cache(alg::Bisection, left, right, p, ::Val{false})
9+
nothing
10+
end
11+
112
"""
213
bisection(f, tup ; maxiters=1000)
314
@@ -30,8 +41,8 @@ function bisection(f, tup ; maxiters=1000)
3041
fl = f(left)
3142
fr = f(right)
3243

33-
fl * fr >= fzero && error("Bracket became non-containing in between iterations. This could mean that"
34-
+ " input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
44+
fl * fr >= fzero && error("Bracket became non-containing in between iterations. This could mean that "
45+
+ "input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
3546

3647
mid = (left + right) / 2.0
3748
fm = f(mid)
@@ -69,3 +80,37 @@ function bisection(f, tup ; maxiters=1000)
6980
end
7081
end
7182
end
83+
84+
function perform_step!(solver, alg::Bisection, cache)
85+
@unpack f, p, left, right, fl, fr = solver
86+
87+
fzero = zero(fl)
88+
fl * fr > fzero && error("Bracket became non-containing in between iterations. This could mean that "
89+
+ "input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
90+
91+
mid = (left + right) / 2.0
92+
93+
if right == mid || right == mid
94+
solver.force_stop = true
95+
solver.retcode = :FloatingPointLimit
96+
return
97+
end
98+
99+
fm = f(mid, p)
100+
101+
if iszero(fm)
102+
# todo: phase 2 bisection similar to the raw method
103+
solver.force_stop = true
104+
solver.left = mid
105+
solver.fl = fm
106+
solver.retcode = :ExactSolutionAtLeft
107+
else
108+
if sign(fm) == sign(fl)
109+
solver.left = mid
110+
solver.fl = fm
111+
else
112+
solver.right = mid
113+
solver.fr = fm
114+
end
115+
end
116+
end

src/falsi.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
struct Falsi <: AbstractBracketingAlgorithm
2+
end
3+
4+
function alg_cache(alg::Falsi, left, right, p, ::Val{true})
5+
nothing
6+
end
7+
8+
function alg_cache(alg::Falsi, left, right, p, ::Val{false})
9+
nothing
10+
end
11+
12+
113
"""
214
falsi(f, tup ; maxiters=1000)
315
@@ -72,3 +84,38 @@ function falsi(f, tup ; maxiters=1000)
7284
end
7385
end
7486
end
87+
88+
function perform_step!(solver, alg::Falsi, cache)
89+
@unpack f, p, left, right, fl, fr = solver
90+
91+
fzero = zero(fl)
92+
fl * fr > fzero && error("Bracket became non-containing in between iterations. This could mean that "
93+
+ "input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
94+
95+
mid = (fr * left - fl * right) / (fr - fl)
96+
97+
if right == mid || right == mid
98+
solver.force_stop = true
99+
solver.retcode = :FloatingPointLimit
100+
return nothing
101+
end
102+
103+
fm = f(mid, p)
104+
105+
if iszero(fm)
106+
# todo: phase 2 bisection similar to the raw method
107+
solver.force_stop = true
108+
solver.left = mid
109+
solver.fl = fm
110+
solver.retcode = :ExactSolutionAtLeft
111+
else
112+
if sign(fm) == sign(fl)
113+
solver.left = mid
114+
solver.fl = fm
115+
else
116+
solver.right = mid
117+
solver.fr = fm
118+
end
119+
end
120+
return nothing
121+
end

src/solve.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
function DiffEqBase.__solve(prob::NonlinearProblem,
2+
alg::AbstractNonlinearSolveAlgorithm, args...;
3+
kwargs...)
4+
solver = DiffEqBase.__init(prob, alg, args...; kwargs...)
5+
solve!(solver)
6+
return solver.sol
7+
end
8+
9+
function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
10+
alias_u0 = false,
11+
maxiters = 1000,
12+
kwargs...
13+
) where {uType, iip}
14+
15+
if !(prob.u0 isa Tuple)
16+
error("You need to pass a tuple of u0 in bracketing algorithms.")
17+
end
18+
19+
if alias_u0
20+
left, right = prob.u0
21+
else
22+
left, right = deepcopy(prob.u0)
23+
end
24+
f = prob.f
25+
p = prob.p
26+
fl = f(left, p)
27+
fr = f(right, p)
28+
29+
cache = alg_cache(alg, left, right, p, Val(iip))
30+
31+
sol = build_solution(left, Val(iip))
32+
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
33+
end
34+
35+
function DiffEqBase.solve!(solver::BracketingSolver)
36+
# sync_residuals!(solver)
37+
mic_check!(solver)
38+
while !solver.force_stop && solver.iter < solver.maxiters
39+
if check_for_exact_solution!(solver)
40+
break
41+
else
42+
perform_step!(solver, solver.alg, solver.cache)
43+
solver.iter += 1
44+
end
45+
# sync_residuals!(solver)
46+
end
47+
if solver.iter == solver.maxiters
48+
solver.retcode = :MaxitersExceeded
49+
end
50+
set_solution!(solver)
51+
return solver.sol
52+
end
53+
54+
function mic_check!(solver::BracketingSolver)
55+
@unpack f, fl, fr = solver
56+
flr = fl * fr
57+
fzero = zero(flr)
58+
(flr > fzero) && error("Non bracketing interval passed in bracketing method.")
59+
if fl == fzero
60+
solver.force_stop = true
61+
solver.retcode = :ExactSolutionAtLeft
62+
elseif fr == fzero
63+
solver.force_stop = true
64+
solver.retcode = :ExactSolutionAtRight
65+
end
66+
nothing
67+
end
68+
69+
function check_for_exact_solution!(solver::BracketingSolver)
70+
@unpack fl, fr = solver
71+
fzero = zero(fl)
72+
if fl == fzero
73+
solver.retcode = :ExactSolutionAtLeft
74+
return true
75+
elseif fr == fzero
76+
solver.retcode = :ExactSolutionAtRight
77+
return true
78+
end
79+
return false
80+
end
81+
82+
function set_solution!(solver)
83+
solver.sol.left = solver.left
84+
solver.sol.right = solver.right
85+
solver.sol.retcode = solver.retcode
86+
end

src/types.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType, solType}
2+
iter::Int
3+
f::fType
4+
alg::algType
5+
left::uType
6+
right::uType
7+
fl::resType
8+
fr::resType
9+
p::pType
10+
cache::cacheType
11+
force_stop::Bool
12+
maxiters::Int
13+
retcode::Symbol
14+
sol::solType
15+
end
16+
17+
function sync_residuals!(solver::BracketingSolver)
18+
solver.fl = solver.f(solver.left, solver.p)
19+
solver.fr = solver.f(solver.right, solver.p)
20+
nothing
21+
end
22+
23+
mutable struct BracketingSolution{uType}
24+
left::uType
25+
right::uType
26+
retcode::Symbol
27+
end
28+
29+
function build_solution(u_prototype, ::Val{true})
30+
return BracketingSolution(similar(u_prototype), similar(u_prototype), :Default)
31+
end
32+
33+
function build_solution(u_prototype, ::Val{false})
34+
return BracketingSolution(zero(u_prototype), zero(u_prototype), :Default)
35+
end

0 commit comments

Comments
 (0)