Skip to content

Commit 87968d9

Browse files
committed
(feat) IIP Raphson
1 parent 2a9c162 commit 87968d9

File tree

5 files changed

+77
-11
lines changed

5 files changed

+77
-11
lines changed

src/jacobian.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,23 @@ function jacobian(f, x, solver)
2323
end
2424
return J
2525
end
26+
27+
function calc_J!(J, solver, cache)
28+
@unpack f, u, p, alg = solver
29+
@unpack du1, uf, jac_config = cache
30+
31+
uf.f = f
32+
uf.p = p
33+
34+
jacobian!(J, uf, u, du1, solver, jac_config)
35+
end
36+
37+
function jacobian!(J, f, x, fx, solver, jac_config)
38+
alg = solver.alg
39+
if alg_autodiff(alg)
40+
ForwardDiff.jacobian!(J, f, fx, x, jac_config)
41+
else
42+
FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config)
43+
end
44+
nothing
45+
end

src/raphson.jl

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,67 @@
1-
struct NewtonRaphson{CS, AD, DT} <: AbstractNewtonAlgorithm{CS,AD}
1+
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD}
22
diff_type::DT
3+
linsolve::L
34
end
45

5-
function NewtonRaphson(;autodiff=true,chunk_size=12,diff_type=Val{:forward})
6-
NewtonRaphson{chunk_size, autodiff, typeof(diff_type)}(diff_type)
6+
function NewtonRaphson(;autodiff=true,chunk_size=12,diff_type=Val{:forward},linsolve=DEFAULT_LINSOLVE)
7+
NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type, linsolve)
78
end
89

9-
mutable struct NewtonRaphsonCache{ufType}
10+
mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
11+
uf::ufType
12+
linsolve::L
13+
J::jType
14+
du1::uType
15+
jac_config::JC
16+
end
17+
18+
mutable struct NewtonRaphsonConstantCache{ufType}
1019
uf::ufType
1120
end
1221

1322
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
1423
uf = JacobianWrapper(f,p)
15-
NewtonRaphsonCache(uf)
24+
linsolve = alg.linsolve(Val{:init}, f, u)
25+
J = false .* u .* u'
26+
du1 = zero(u)
27+
tmp = zero(u)
28+
if alg_autodiff(alg)
29+
jac_config = ForwardDiff.JacobianConfig(uf, du1, u)
30+
else
31+
if alg.diff_type != Val{:complex}
32+
du2 = zero(u)
33+
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg.diff_type)
34+
else
35+
jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp),Complex{eltype(du1)}.(du1),nothing,alg.diff_type,eltype(u))
36+
end
37+
end
38+
NewtonRaphsonCache(uf, linsolve, J, du1, jac_config)
1639
end
1740

1841
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{false})
1942
uf = JacobianWrapper(f,p)
20-
NewtonRaphsonCache(uf)
43+
NewtonRaphsonConstantCache(uf)
2144
end
2245

23-
function perform_step!(solver, alg::NewtonRaphson, cache)
46+
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonConstantCache)
2447
@unpack u, fu, f, p = solver
2548
J = calc_J(solver, cache)
2649
solver.u = u - J \ fu
2750
solver.fu = f(solver.u, p)
2851
if iszero(solver.fu) || abs(solver.fu) < solver.tol
2952
solver.force_stop = true
3053
end
31-
end
54+
end
55+
56+
function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonCache)
57+
@unpack u, fu, f, p = solver
58+
@unpack J, linsolve, du1 = cache
59+
calc_J!(J, solver, cache)
60+
# u = u - J \ fu
61+
linsolve(du1, J, fu, true)
62+
@. u = u - du1
63+
f(fu, u, p)
64+
if solver.internalnorm(solver.fu) < solver.tol
65+
solver.force_stop = true
66+
end
67+
end

src/solve.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
3636
alias_u0 = false,
3737
maxiters = 1000,
3838
tol = 1e-6,
39+
internalnorm = Base.Fix2(DiffEqBase.ODE_DEFAULT_NORM, nothing),
3940
kwargs...
4041
) where {uType, iip}
4142

@@ -46,12 +47,17 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
4647
end
4748
f = prob.f
4849
p = prob.p
49-
fu = f(u, p)
50+
if iip
51+
fu = zero(u)
52+
f(fu, u, p)
53+
else
54+
fu = f(u, p)
55+
end
5056

5157
cache = alg_cache(alg, f, u, p, Val(iip))
5258

5359
sol = build_newton_solution(u, Val(iip))
54-
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, :Default, tol, sol)
60+
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol, sol)
5561
end
5662

5763
function DiffEqBase.solve!(solver::AbstractNonlinearSolver)

src/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType
1414
sol::solType
1515
end
1616

17-
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, tolType, solType} <: AbstractNonlinearSolver
17+
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType, solType} <: AbstractNonlinearSolver
1818
iter::Int
1919
f::fType
2020
alg::algType
@@ -24,6 +24,7 @@ mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, to
2424
cache::cacheType
2525
force_stop::Bool
2626
maxiters::Int
27+
internalnorm::INType
2728
retcode::Symbol
2829
tol::tolType
2930
sol::solType

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@ function value_derivative(f::F, x::R) where {F,R}
2424
out = f(ForwardDiff.Dual{T}(x, one(x)))
2525
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
2626
end
27+
28+
DiffEqBase.has_Wfact(f::Function) = false
29+
DiffEqBase.has_Wfact_t(f::Function) = false

0 commit comments

Comments
 (0)