Skip to content

Commit 2a9c162

Browse files
committed
(feat) non iterator solvers for oop
1 parent c695450 commit 2a9c162

File tree

6 files changed

+102
-155
lines changed

6 files changed

+102
-155
lines changed

src/NonlinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@ module NonlinearSolve
1212

1313
include("jacobian.jl")
1414
include("types.jl")
15-
include("solve.jl")
1615
include("utils.jl")
16+
include("solve.jl")
1717
include("bisection.jl")
1818
include("falsi.jl")
1919
include("raphson.jl")
20+
include("scalar.jl")
2021

2122
# raw methods
2223
export bisection, falsi
2324

2425
# DiffEq styled algorithms
2526
export Bisection, Falsi, NewtonRaphson
27+
export ScalarBisection, ScalarNewton
2628
end # module

src/bisection.jl

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,78 +21,6 @@ function alg_cache(alg::Bisection, left, right, p, ::Val{false})
2121
BisectionCache(UInt8(0), left, right)
2222
end
2323

24-
"""
25-
bisection(f, tup ; maxiters=1000)
26-
27-
Uses bisection method to find the root of the function `f` between a tuple `tup` of values.
28-
"""
29-
function bisection(f, tup ; maxiters=1000)
30-
x0, x1 = tup
31-
fx0, fx1 = f(x0), f(x1)
32-
fx0x1 = fx0 * fx1
33-
fzero = zero(fx0x1)
34-
35-
(fx0x1 > fzero) && error("Non bracketing interval passed in bisection method.")
36-
# NOTE: fx0x1 = 0 can mean that both fx0 and fx1 are very small and multiplication of them
37-
# could be less than the smallest float, hence could be zero.
38-
39-
fx0 == fzero && return x0 # should replace with some tolerance compare i.e. ≈
40-
fx1 == fzero && return x1
41-
42-
left = x0
43-
right = x1
44-
45-
iter = 0
46-
while true
47-
iter += 1
48-
49-
if iter == maxiters
50-
return left
51-
end
52-
53-
fl = f(left)
54-
fr = f(right)
55-
56-
fl * fr >= fzero && error("Bracket became non-containing in between iterations. This could mean that "
57-
+ "input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
58-
59-
mid = (left + right) / 2.0
60-
fm = f(mid)
61-
if iszero(fm)
62-
# we are in the region of zero, inner loop
63-
right = mid
64-
while true
65-
iter += 1
66-
67-
if iter == maxiters
68-
return left
69-
end
70-
71-
mid = (left + right) / 2.0
72-
(left == mid || right == mid) && return left
73-
fm = f(mid)
74-
75-
if iszero(fm)
76-
if !iszero(f(prevfloat_tdir(mid, x0, x1)))
77-
return prevfloat_tdir(mid, x0, x1)
78-
end
79-
right = mid
80-
else
81-
left = mid
82-
end
83-
84-
end
85-
end
86-
87-
(left == mid || right == mid) && return left
88-
if sign(fm) == sign(fl)
89-
left = mid
90-
else
91-
right = mid
92-
end
93-
end
94-
end
95-
9624
function perform_step!(solver, alg::Bisection, cache)
9725
@unpack f, p, left, right, fl, fr = solver
9826

src/falsi.jl

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,82 +9,6 @@ function alg_cache(alg::Falsi, left, right, p, ::Val{false})
99
nothing
1010
end
1111

12-
13-
"""
14-
falsi(f, tup ; maxiters=1000)
15-
16-
Uses Regula-Falsi method to find the root of the function `f` between a tuple `tup = (x0, x1)` of values.
17-
It doesn't find the exact value at which f(x) = 0, but finds an x where the next float in the direction of `x1`
18-
gives the evaluation of `f` to be zero. If `f` is zero for a region of x, it returns the left-most (in the
19-
direction of x0) such value.
20-
"""
21-
function falsi(f, tup ; maxiters=1000)
22-
x0, x1 = tup
23-
fx0, fx1 = f(x0), f(x1)
24-
fx0x1 = fx0 * fx1
25-
fzero = zero(fx0x1)
26-
27-
(fx0x1 > fzero) && error("Non bracketing interval passed in bisection method.")
28-
# NOTE: fx0x1 = 0 can mean that both fx0 and fx1 are very small and multiplication of them
29-
# could be less than the smallest float, hence could be zero.
30-
31-
fx0 == fzero && return x0 # should replace with some tolerance compare i.e. ≈
32-
fx1 == fzero && return x1
33-
34-
left = x0
35-
right = x1
36-
37-
iter = 0
38-
while true
39-
iter += 1
40-
41-
if iter == maxiters
42-
return left
43-
end
44-
45-
fl = f(left)
46-
fr = f(right)
47-
48-
fl * fr >= fzero && error("Bracket became non-containing in between iterations. This could mean that"
49-
+ " input function crosses the x axis multiple times. Bisection is not the right method to solve this.")
50-
51-
mid = (fr * left - fl * right) / (fr - fl)
52-
fm = f(mid)
53-
if iszero(fm)
54-
# we are in the region of zero, inner loop
55-
right = mid
56-
while true
57-
iter += 1
58-
59-
if iter == maxiters
60-
return left
61-
end
62-
63-
mid = (left + right) / 2.0
64-
(left == mid || right == mid) && return left
65-
fm = f(mid)
66-
67-
if iszero(fm)
68-
if !iszero(f(prevfloat_tdir(mid, x0, x1)))
69-
return prevfloat_tdir(mid, x0, x1)
70-
end
71-
right = mid
72-
else
73-
left = mid
74-
end
75-
76-
end
77-
end
78-
79-
(left == mid || right == mid) && return left
80-
if sign(fm) == sign(fl)
81-
left = mid
82-
else
83-
right = mid
84-
end
85-
end
86-
end
87-
8812
function perform_step!(solver, alg::Falsi, cache)
8913
@unpack f, p, left, right, fl, fr = solver
9014

src/scalar.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
ScalarNewton
3+
4+
Fast Newton Raphson for scalar problems.
5+
"""
6+
struct ScalarNewton <: AbstractNonlinearSolveAlgorithm end
7+
8+
function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarNewton, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) where {uType}
9+
f = Base.Fix2(prob.f, prob.p)
10+
x = float(prob.u0)
11+
T = typeof(x)
12+
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
13+
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
14+
15+
xo = oftype(x, Inf)
16+
for i in 1:maxiters
17+
fx, dfx = value_derivative(f, x)
18+
iszero(fx) && return x
19+
Δx = dfx \ fx
20+
x -= Δx
21+
if isapprox(x, xo, atol=atol, rtol=rtol)
22+
return x
23+
end
24+
xo = x
25+
end
26+
return oftype(x, NaN)
27+
end
28+
29+
"""
30+
ScalarBisection
31+
32+
Fast Bisection for scalar problems. Note that it doesn't returns exact solution, but returns
33+
the best left limit of the exact solution.
34+
"""
35+
struct ScalarBisection <: AbstractNonlinearSolveAlgorithm end
36+
37+
function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarBisection, args...; maxiters = 1000, kwargs...) where {uType}
38+
f = Base.Fix2(prob.f, prob.p)
39+
left, right = prob.u0
40+
fl, fr = f(left), f(right)
41+
42+
if iszero(fl)
43+
return fl
44+
end
45+
46+
i = 1
47+
if !iszero(fr)
48+
while i < maxiters
49+
mid = (left + right) / 2
50+
(mid == left || mid == right) && return left
51+
fm = f(mid)
52+
if iszero(fm)
53+
right = mid
54+
break
55+
end
56+
if sign(fl) == sign(fm)
57+
fl = fm
58+
left = mid
59+
else
60+
fr = fm
61+
right = mid
62+
end
63+
i += 1
64+
end
65+
end
66+
67+
while i < maxiters
68+
mid = (left + right) / 2
69+
(mid == left || mid == right) && return left
70+
fm = f(mid)
71+
if iszero(fm)
72+
right = mid
73+
fr = fm
74+
else
75+
left = mid
76+
fl = fm
77+
end
78+
i += 1
79+
end
80+
81+
return left
82+
end

src/solve.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function DiffEqBase.__solve(prob::NonlinearProblem,
2-
alg::AbstractNonlinearSolveAlgorithm, args...;
3-
kwargs...)
4-
solver = DiffEqBase.__init(prob, alg, args...; kwargs...)
1+
function DiffEqBase.solve(prob::NonlinearProblem,
2+
alg::AbstractNonlinearSolveAlgorithm, args...;
3+
kwargs...)
4+
solver = DiffEqBase.init(prob, alg, args...; kwargs...)
55
solve!(solver)
66
return solver.sol
77
end
88

9-
function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
9+
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
1212
kwargs...
@@ -32,7 +32,7 @@ function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractBrac
3232
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
3333
end
3434

35-
function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
35+
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
3636
alias_u0 = false,
3737
maxiters = 1000,
3838
tol = 1e-6,

src/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,14 @@ end
1313

1414
alg_autodiff(alg::AbstractNewtonAlgorithm{CS,AD}) where {CS,AD} = AD
1515
alg_autodiff(alg) = false
16+
17+
"""
18+
value_derivative(f, x)
19+
20+
Compute `f(x), d/dx f(x)` in the most efficient way.
21+
"""
22+
function value_derivative(f::F, x::R) where {F,R}
23+
T = typeof(ForwardDiff.Tag(f, R))
24+
out = f(ForwardDiff.Dual{T}(x, one(x)))
25+
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
26+
end

0 commit comments

Comments
 (0)