Skip to content

Commit 5b72760

Browse files
committed
Non-mutable Iterator
1 parent 87968d9 commit 5b72760

File tree

7 files changed

+174
-8
lines changed

7 files changed

+174
-8
lines changed

Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,13 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
11+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1113
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
14+
15+
[extras]
16+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
17+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
19+
[targets]
20+
test = ["BenchmarkTools", "Test"]

src/NonlinearSolve.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ module NonlinearSolve
44
@reexport using DiffEqBase
55
using UnPack: @unpack
66
using FiniteDiff, ForwardDiff
7+
using Setfield
8+
using StaticArrays
79

810
abstract type AbstractNonlinearSolveAlgorithm end
911
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
1012
abstract type AbstractNewtonAlgorithm{CS,AD} <: AbstractNonlinearSolveAlgorithm end
1113
abstract type AbstractNonlinearSolver end
14+
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolver end
1215

1316
include("jacobian.jl")
1417
include("types.jl")
@@ -25,4 +28,6 @@ module NonlinearSolve
2528
# DiffEq styled algorithms
2629
export Bisection, Falsi, NewtonRaphson
2730
export ScalarBisection, ScalarNewton
31+
32+
export reinit!
2833
end # module

src/jacobian.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ end
66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9+
mutable struct ImmutableJacobianWrapper{fType, pType}
10+
f::fType
11+
p::pType
12+
end
13+
14+
(uf::ImmutableJacobianWrapper)(u) = uf.f(u, uf.p)
15+
916
function calc_J(solver, cache)
1017
@unpack u, f, p, alg = solver
1118
@unpack uf = cache
@@ -15,9 +22,17 @@ function calc_J(solver, cache)
1522
return J
1623
end
1724

25+
function calc_J(solver, uf::ImmutableJacobianWrapper)
26+
@unpack u, f, p, alg = solver
27+
@set! uf.f = f
28+
@set! uf.p = p
29+
J = jacobian(uf, u, solver)
30+
return J
31+
end
32+
1833
function jacobian(f, x, solver)
1934
if alg_autodiff(solver.alg)
20-
J = ForwardDiff.derivative(f, x)
35+
J = ForwardDiff.jacobian(f, Ref(x)[])
2136
else
2237
J = FiniteDiff.finite_difference_derivative(f, x, solver.alg.diff_type, eltype(x))
2338
end

src/raphson.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
1515
jac_config::JC
1616
end
1717

18-
mutable struct NewtonRaphsonConstantCache{ufType}
18+
struct NewtonRaphsonConstantCache{ufType}
1919
uf::ufType
2020
end
2121

@@ -48,7 +48,7 @@ function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonConstantC
4848
J = calc_J(solver, cache)
4949
solver.u = u - J \ fu
5050
solver.fu = f(solver.u, p)
51-
if iszero(solver.fu) || abs(solver.fu) < solver.tol
51+
if iszero(solver.fu) || solver.internalnorm(solver.fu) < solver.tol
5252
solver.force_stop = true
5353
end
5454
end
@@ -65,3 +65,15 @@ function perform_step!(solver, alg::NewtonRaphson, cache::NewtonRaphsonCache)
6565
solver.force_stop = true
6666
end
6767
end
68+
69+
function perform_step(solver, alg::NewtonRaphson)
70+
@unpack u, fu, f, p = solver
71+
J = calc_J(solver, ImmutableJacobianWrapper(f, p))
72+
@set! solver.u = u - J \ fu
73+
fu = f(solver.u, p)
74+
@set! solver.fu = fu
75+
if iszero(solver.fu) || solver.internalnorm(solver.fu) < solver.tol
76+
@set! solver.force_stop = true
77+
end
78+
return solver
79+
end

src/solve.jl

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ end
99
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
12+
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
1213
kwargs...
1314
) where {uType, iip}
1415

@@ -29,12 +30,17 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
2930
cache = alg_cache(alg, left, right, p, Val(iip))
3031

3132
sol = build_solution(left, Val(iip))
32-
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
33+
if immutable
34+
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
35+
else
36+
return BracketingSolver(1, f, alg, left, right, fl, fr, p, cache, false, maxiters, :Default, sol)
37+
end
3338
end
3439

3540
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
3641
alias_u0 = false,
3742
maxiters = 1000,
43+
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
3844
tol = 1e-6,
3945
internalnorm = Base.Fix2(DiffEqBase.ODE_DEFAULT_NORM, nothing),
4046
kwargs...
@@ -54,10 +60,14 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
5460
fu = f(u, p)
5561
end
5662

57-
cache = alg_cache(alg, f, u, p, Val(iip))
58-
63+
5964
sol = build_newton_solution(u, Val(iip))
60-
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol, sol)
65+
if immutable
66+
return NewtonImmutableSolver(1, f, alg, u, fu, p, nothing, false, maxiters, internalnorm, :Default, tol, sol)
67+
else
68+
cache = alg_cache(alg, f, u, p, Val(iip))
69+
return NewtonSolver(1, f, alg, u, fu, p, cache, false, maxiters, internalnorm, :Default, tol, sol)
70+
end
6171
end
6272

6373
function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
@@ -75,6 +85,21 @@ function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
7585
return solver.sol
7686
end
7787

88+
function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
89+
# sync_residuals!(solver)
90+
solver = mic_check(solver)
91+
while !solver.force_stop && solver.iter < solver.maxiters
92+
solver = perform_step(solver, solver.alg)
93+
@set! solver.iter += 1
94+
# sync_residuals!(solver)
95+
end
96+
if solver.iter == solver.maxiters
97+
@set! solver.retcode = :MaxitersExceeded
98+
end
99+
sol = get_solution(solver)
100+
return sol
101+
end
102+
78103
function mic_check!(solver::BracketingSolver)
79104
@unpack f, fl, fr = solver
80105
flr = fl * fr
@@ -90,10 +115,29 @@ function mic_check!(solver::BracketingSolver)
90115
nothing
91116
end
92117

118+
function mic_check(solver::BracketingImmutableSolver)
119+
@unpack f, fl, fr = solver
120+
flr = fl * fr
121+
fzero = zero(flr)
122+
(flr > fzero) && error("Non bracketing interval passed in bracketing method.")
123+
if fl == fzero
124+
@set! solver.force_stop = true
125+
@set! solver.retcode = :ExactSolutionAtLeft
126+
elseif fr == fzero
127+
@set! solver.force_stop = true
128+
@set! solver.retcode = :ExactSolutionAtRight
129+
end
130+
solver
131+
end
132+
93133
function mic_check!(solver::NewtonSolver)
94134
nothing
95135
end
96136

137+
function mic_check(solver::NewtonImmutableSolver)
138+
solver
139+
end
140+
97141
function check_for_exact_solution!(solver::BracketingSolver)
98142
@unpack fl, fr = solver
99143
fzero = zero(fl)
@@ -113,7 +157,27 @@ function set_solution!(solver::BracketingSolver)
113157
solver.sol.retcode = solver.retcode
114158
end
115159

160+
function get_solution(solver::BracketingImmutableSolver)
161+
return (left = solver.left, right = solver.right, retcode = solver.retcode)
162+
end
163+
116164
function set_solution!(solver::NewtonSolver)
117165
solver.sol.u = solver.u
118166
solver.sol.retcode = solver.retcode
119-
end
167+
end
168+
169+
function get_solution(solver::NewtonImmutableSolver)
170+
return (u = solver.u, retcode = solver.retcode)
171+
end
172+
173+
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, true}) where {uType}
174+
@. solver.u = prob.u0
175+
solver.iter = 1
176+
solver.force_stop = false
177+
end
178+
179+
function reinit!(solver::NewtonSolver, prob::NonlinearProblem{uType, false}) where {uType}
180+
solver.u = prob.u0
181+
solver.iter = 1
182+
solver.force_stop = false
183+
end

src/types.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,23 @@ mutable struct BracketingSolver{fType, algType, uType, resType, pType, cacheType
1414
sol::solType
1515
end
1616

17+
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, solType} <: AbstractImmutableNonlinearSolver
18+
iter::Int
19+
f::fType
20+
alg::algType
21+
left::uType
22+
right::uType
23+
fl::resType
24+
fr::resType
25+
p::pType
26+
cache::cacheType
27+
force_stop::Bool
28+
maxiters::Int
29+
retcode::Symbol
30+
sol::solType
31+
end
32+
33+
1734
mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType, solType} <: AbstractNonlinearSolver
1835
iter::Int
1936
f::fType
@@ -30,6 +47,22 @@ mutable struct NewtonSolver{fType, algType, uType, resType, pType, cacheType, IN
3047
sol::solType
3148
end
3249

50+
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, cacheType, INType, tolType, solType} <: AbstractImmutableNonlinearSolver
51+
iter::Int
52+
f::fType
53+
alg::algType
54+
u::uType
55+
fu::resType
56+
p::pType
57+
cache::cacheType
58+
force_stop::Bool
59+
maxiters::Int
60+
internalnorm::INType
61+
retcode::Symbol
62+
tol::tolType
63+
sol::solType
64+
end
65+
3366
function sync_residuals!(solver::BracketingSolver)
3467
solver.fl = solver.f(solver.left, solver.p)
3568
solver.fr = solver.f(solver.right, solver.p)

test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using NonlinearSolve
2+
using StaticArrays
3+
using BenchmarkTools
4+
using Test
5+
6+
function benchmark_immutable()
7+
probN = NonlinearProblem((u,p) -> u .* u .- 2, @SVector[1.0, 1.0])
8+
solver = init(probN, NewtonRaphson(), immutable = true, tol = 1e-9)
9+
sol = @btime solve!($solver)
10+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
11+
end
12+
13+
function benchmark_mutable()
14+
probN = NonlinearProblem((u,p) -> u .* u .- 2, @SVector[1.0, 1.0])
15+
solver = init(probN, NewtonRaphson(), immutable = false, tol = 1e-9)
16+
sol = @btime (reinit!($solver, $probN); solve!($solver))
17+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
18+
end
19+
20+
function benchmark_scalar()
21+
probN = NonlinearProblem((u,p) -> u .* u .- 2, 1.0)
22+
sol = @btime (solve($probN, ScalarNewton()))
23+
@test sol * sol - 2 < 1e-9
24+
end
25+
26+
benchmark_immutable()
27+
benchmark_mutable()
28+
benchmark_scalar()

0 commit comments

Comments
 (0)