Skip to content

Commit f3875f6

Browse files
Merge pull request #173 from dawbarton/master
Added reinit! for NewtonRaphson
2 parents 85b94ac + cf1486f commit f3875f6

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/raphson.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,22 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)
198198
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
199199
retcode = cache.retcode)
200200
end
201+
202+
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u0; p = cache.p,
203+
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
204+
cache.p = p
205+
if iip
206+
recursivecopy!(cache.u, u0)
207+
cache.f(cache.fu, cache.u, p)
208+
else
209+
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
210+
cache.u = u0
211+
cache.fu = cache.f(cache.u, p)
212+
end
213+
cache.abstol = abstol
214+
cache.maxiters = maxiters
215+
cache.iter = 1
216+
cache.force_stop = false
217+
cache.retcode = ReturnCode.Default
218+
return cache
219+
end

test/basictests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,37 @@ end
111111
@test gnewton(p) [sqrt(p[2] / p[1])]
112112
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
113113

114+
# Iterator interface
115+
f = (u, p) -> u * u - p
116+
g = function (p_range)
117+
probN = NonlinearProblem{false}(f, 0.5, p_range[begin])
118+
cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10)
119+
sols = zeros(length(p_range))
120+
for (i, p) in enumerate(p_range)
121+
reinit!(cache, cache.u; p = p)
122+
sol = solve!(cache)
123+
sols[i] = sol.u
124+
end
125+
return sols
126+
end
127+
p = range(0.01, 2, length = 200)
128+
@test g(p) sqrt.(p)
129+
130+
f = (res, u, p) -> (res[begin] = u[1] * u[1] - p)
131+
g = function (p_range)
132+
probN = NonlinearProblem{true}(f, [0.5], p_range[begin])
133+
cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10)
134+
sols = zeros(length(p_range))
135+
for (i, p) in enumerate(p_range)
136+
reinit!(cache, [cache.u[1]]; p = p)
137+
sol = solve!(cache)
138+
sols[i] = sol.u[1]
139+
end
140+
return sols
141+
end
142+
p = range(0.01, 2, length = 200)
143+
@test g(p) sqrt.(p)
144+
114145
# Error Checks
115146

116147
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]

0 commit comments

Comments
 (0)