Skip to content

Commit 08f7f1d

Browse files
Merge pull request #240 from SciML/default_cases
Change default in-place specialization
2 parents d327107 + 143c634 commit 08f7f1d

File tree

2 files changed

+86
-4
lines changed

2 files changed

+86
-4
lines changed

src/default.jl

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::FastShortcutN
138138
)
139139
end
140140

141-
function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::FastShortcutNonlinearPolyalg, args...;
142-
kwargs...) where {uType, iip}
141+
function SciMLBase.__solve(prob::NonlinearProblem{uType, false}, alg::FastShortcutNonlinearPolyalg, args...;
142+
kwargs...) where {uType}
143143

144144
adkwargs = alg.adkwargs
145145
linsolve = alg.linsolve
@@ -190,6 +190,58 @@ function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::FastShortcut
190190

191191
end
192192

193+
function SciMLBase.__solve(prob::NonlinearProblem{uType, true}, alg::FastShortcutNonlinearPolyalg, args...;
194+
kwargs...) where {uType}
195+
196+
adkwargs = alg.adkwargs
197+
linsolve = alg.linsolve
198+
precs = alg.precs
199+
200+
sol1 = SciMLBase.__solve(prob, NewtonRaphson(;linsolve, precs, adkwargs...), args...; kwargs...)
201+
if SciMLBase.successful_retcode(sol1)
202+
return SciMLBase.build_solution(prob, alg, sol1.u, sol1.resid;
203+
sol1.retcode, sol1.stats)
204+
end
205+
206+
sol2 = SciMLBase.__solve(prob, NewtonRaphson(;linsolve, precs, linesearch=BackTracking(), adkwargs...), args...; kwargs...)
207+
if SciMLBase.successful_retcode(sol2)
208+
return SciMLBase.build_solution(prob, alg, sol2.u, sol2.resid;
209+
sol2.retcode, sol2.stats)
210+
end
211+
212+
sol3 = SciMLBase.__solve(prob, TrustRegion(;linsolve, precs, adkwargs...), args...; kwargs...)
213+
if SciMLBase.successful_retcode(sol3)
214+
return SciMLBase.build_solution(prob, alg, sol3.u, sol3.resid;
215+
sol3.retcode, sol3.stats)
216+
end
217+
218+
sol4 = SciMLBase.__solve(prob, TrustRegion(;linsolve, precs, radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...), args...; kwargs...)
219+
if SciMLBase.successful_retcode(sol4)
220+
return SciMLBase.build_solution(prob, alg, sol4.u, sol4.resid;
221+
sol4.retcode, sol4.stats)
222+
end
223+
224+
resids = (sol1.resid, sol2.resid, sol3.resid, sol4.resid)
225+
minfu, idx = findmin(DEFAULT_NORM, resids)
226+
227+
if idx == 1
228+
SciMLBase.build_solution(prob, alg, sol1.u, sol1.resid;
229+
sol1.retcode, sol1.stats)
230+
elseif idx == 2
231+
SciMLBase.build_solution(prob, alg, sol2.u, sol2.resid;
232+
sol2.retcode, sol2.stats)
233+
elseif idx == 3
234+
SciMLBase.build_solution(prob, alg, sol3.u, sol3.resid;
235+
sol3.retcode, sol3.stats)
236+
elseif idx == 4
237+
SciMLBase.build_solution(prob, alg, sol4.u, sol4.resid;
238+
sol4.retcode, sol4.stats)
239+
else
240+
error("Unreachable reached, 박정석")
241+
end
242+
243+
end
244+
193245
## General shared polyalg functions
194246

195247
function perform_step!(cache::Union{RobustMultiNewtonCache, FastShortcutNonlinearPolyalgCache})

test/polyalgs.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,38 @@
1-
using NonlinearSolve
1+
using NonlinearSolve, Test
22

33
f(u, p) = u .* u .- 2
44
u0 = [1.0, 1.0]
55
probN = NonlinearProblem(f, u0)
66
@time solver = solve(probN, abstol = 1e-9)
77
@time solver = solve(probN, RobustMultiNewton(), abstol = 1e-9)
8-
@time solver = solve(probN, FastShortcutNonlinearPolyalg(), abstol = 1e-9)
8+
@time solver = solve(probN, FastShortcutNonlinearPolyalg(), abstol = 1e-9)
9+
10+
# https://github.com/SciML/NonlinearSolve.jl/issues/153
11+
12+
function f(du, u, p)
13+
s1, s1s2, s2 = u
14+
k1, c1, Δt = p
15+
16+
du[1] = -0.25 * c1 * k1 * s1 * s2
17+
du[2] = 0.25 * c1 * k1 * s1 * s2
18+
du[3] = -0.25 * c1 * k1 * s1 * s2
19+
end
20+
21+
prob = NonlinearProblem(f, [2.0,2.0,2.0], [1.0, 2.0, 2.5])
22+
sol = solve(prob)
23+
@test SciMLBase.successful_retcode(sol)
24+
25+
# https://github.com/SciML/NonlinearSolve.jl/issues/187
26+
27+
ff(u, p) = 0.5/1.5*log.(u./(1.0.-u)) .- 2.0*u .+1.0
28+
29+
uspan = (0.02, 0.1)
30+
prob = IntervalNonlinearProblem(ff, uspan)
31+
sol = solve(prob)
32+
@test SciMLBase.successful_retcode(sol)
33+
34+
u0 = 0.06
35+
p = 2.0
36+
prob = NonlinearProblem(ff, u0, p)
37+
solver = solve(prob)
38+
@test SciMLBase.successful_retcode(sol)

0 commit comments

Comments
 (0)