Skip to content

Commit c695450

Browse files
committed
(feat) NewtonRaphson
1 parent 6f0a9a5 commit c695450

File tree

7 files changed

+128
-10
lines changed

7 files changed

+128
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
9+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
810
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
911
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/NonlinearSolve.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@ module NonlinearSolve
33
using Reexport
44
@reexport using DiffEqBase
55
using UnPack: @unpack
6+
using FiniteDiff, ForwardDiff
67

78
abstract type AbstractNonlinearSolveAlgorithm end
89
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
10+
abstract type AbstractNewtonAlgorithm{CS,AD} <: AbstractNonlinearSolveAlgorithm end
11+
abstract type AbstractNonlinearSolver end
912

13+
include("jacobian.jl")
1014
include("types.jl")
1115
include("solve.jl")
1216
include("utils.jl")
1317
include("bisection.jl")
1418
include("falsi.jl")
19+
include("raphson.jl")
1520

1621
# raw methods
1722
export bisection, falsi
1823

1924
# DiffEq styled algorithms
20-
export Bisection, Falsi
25+
export Bisection, Falsi, NewtonRaphson
2126
end # module

src/jacobian.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
mutable struct JacobianWrapper{fType, pType}
2+
f::fType
3+
p::pType
4+
end
5+
6+
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
7+
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
8+
9+
function calc_J(solver, cache)
10+
@unpack u, f, p, alg = solver
11+
@unpack uf = cache
12+
uf.f = f
13+
uf.p = p
14+
J = jacobian(uf, u, solver)
15+
return J
16+
end
17+
18+
function jacobian(f, x, solver)
19+
if alg_autodiff(solver.alg)
20+
J = ForwardDiff.derivative(f, x)
21+
else
22+
J = FiniteDiff.finite_difference_derivative(f, x, solver.alg.diff_type, eltype(x))
23+
end
24+
return J
25+
end

src/raphson.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
struct NewtonRaphson{CS, AD, DT} <: AbstractNewtonAlgorithm{CS,AD}
2+
diff_type::DT
3+
end
4+
5+
function NewtonRaphson(;autodiff=true,chunk_size=12,diff_type=Val{:forward})
6+
NewtonRaphson{chunk_size, autodiff, typeof(diff_type)}(diff_type)
7+
end
8+
9+
mutable struct NewtonRaphsonCache{ufType}
10+
uf::ufType
11+
end
12+
13+
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
14+
uf = JacobianWrapper(f,p)
15+
NewtonRaphsonCache(uf)
16+
end
17+
18+
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{false})
19+
uf = JacobianWrapper(f,p)
20+
NewtonRaphsonCache(uf)
21+
end
22+
23+
function perform_step!(solver, alg::NewtonRaphson, cache)
24+
@unpack u, fu, f, p = solver
25+
J = calc_J(solver, cache)
26+
solver.u = u - J \ fu
27+
solver.fu = f(solver.u, p)
28+
if iszero(solver.fu) || abs(solver.fu) < solver.tol
29+
solver.force_stop = true
30+
end
31+
end

src/solve.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,34 @@ function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractBrac
3232
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
3333
end
3434

35-
function DiffEqBase.solve!(solver::BracketingSolver)
35+
function DiffEqBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
36+
alias_u0 = false,
37+
maxiters = 1000,
38+
tol = 1e-6,
39+
kwargs...
40+
) where {uType, iip}
41+
42+
if alias_u0
43+
u = prob.u0
44+
else
45+
u = deepcopy(prob.u0)
46+
end
47+
f = prob.f
48+
p = prob.p
49+
fu = f(u, p)
50+
51+
cache = alg_cache(alg, f, u, p, Val(iip))
52+
53+
sol = build_newton_solution(u, Val(iip))
54+
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, :Default, tol, sol)
55+
end
56+
57+
function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
3658
# sync_residuals!(solver)
3759
mic_check!(solver)
3860
while !solver.force_stop && solver.iter < solver.maxiters
39-
if check_for_exact_solution!(solver)
40-
break
41-
else
42-
perform_step!(solver, solver.alg, solver.cache)
43-
solver.iter += 1
44-
end
61+
perform_step!(solver, solver.alg, solver.cache)
62+
solver.iter += 1
4563
# sync_residuals!(solver)
4664
end
4765
if solver.iter == solver.maxiters
@@ -66,6 +84,10 @@ function mic_check!(solver::BracketingSolver)
6684
nothing
6785
end
6886

87+
function mic_check!(solver::NewtonSolver)
88+
nothing
89+
end
90+
6991
function check_for_exact_solution!(solver::BracketingSolver)
7092
@unpack fl, fr = solver
7193
fzero = zero(fl)
@@ -79,8 +101,13 @@ function check_for_exact_solution!(solver::BracketingSolver)
79101
return false
80102
end
81103

82-
function set_solution!(solver)
104+
function set_solution!(solver::BracketingSolver)
83105
solver.sol.left = solver.left
84106
solver.sol.right = solver.right
85107
solver.sol.retcode = solver.retcode
86108
end
109+
110+
function set_solution!(solver::NewtonSolver)
111+
solver.sol.u = solver.u
112+
solver.sol.retcode = solver.retcode
113+
end

src/types.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType, solType}
1+
mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType, solType} <: AbstractNonlinearSolver
22
iter::Int
33
f::fType
44
alg::algType
@@ -14,6 +14,21 @@ mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType
1414
sol::solType
1515
end
1616

17+
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, tolType, solType} <: AbstractNonlinearSolver
18+
iter::Int
19+
f::fType
20+
alg::algType
21+
u::uType
22+
fu::resType
23+
p::pType
24+
cache::cacheType
25+
force_stop::Bool
26+
maxiters::Int
27+
retcode::Symbol
28+
tol::tolType
29+
sol::solType
30+
end
31+
1732
function sync_residuals!(solver::BracketingSolver)
1833
solver.fl = solver.f(solver.left, solver.p)
1934
solver.fr = solver.f(solver.right, solver.p)
@@ -33,3 +48,13 @@ end
3348
function build_solution(u_prototype, ::Val{false})
3449
return BracketingSolution(zero(u_prototype), zero(u_prototype), :Default)
3550
end
51+
52+
mutable struct NewtonSolution{uType}
53+
u::uType
54+
retcode::Symbol
55+
end
56+
57+
function build_newton_solution(u_prototype, ::Val{iip}) where iip
58+
return NewtonSolution(zero(u_prototype), :Default)
59+
end
60+

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ end
1010
function nextfloat_tdir(x::T, x0::T, x1::T)::T where {T}
1111
x1 > x0 ? nextfloat(x) : prevfloat(x)
1212
end
13+
14+
alg_autodiff(alg::AbstractNewtonAlgorithm{CS,AD}) where {CS,AD} = AD
15+
alg_autodiff(alg) = false

0 commit comments

Comments
 (0)