Skip to content

Commit 66a0a7d

Browse files
Add polyalgorithms and a default method
This adds two polyalgorithms, one focused on ultra-robustness and another that balances robustness and performance, and sets the default algorithm to the balanced one. There are still a lot of improvements that can be made to this, but I think it is good enough to at least have a lot of utility. The defaults can probably specialize on StaticArrays and similar types to be optimized further, but that's fine for now.
1 parent 364d7f4 commit 66a0a7d

File tree

4 files changed

+334
-1
lines changed

4 files changed

+334
-1
lines changed

src/NonlinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ include("levenberg.jl")
6868
include("gaussnewton.jl")
6969
include("jacobian.jl")
7070
include("ad.jl")
71+
include("default.jl")
7172

7273
import PrecompileTools
7374

@@ -95,6 +96,7 @@ export RadiusUpdateSchemes
9596

9697
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
9798
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
99+
export RobustMultiNewton, FastShortcutNonlinearPolyalg
98100

99101
export LineSearch
100102

src/default.jl

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
"""
    RobustMultiNewton(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, adkwargs...)

Robustness-first polyalgorithm. It combines several Newton-type methods that differ in
their globalization technique (trust-region radius updates, line searches, etc.) in order
to find a method that is able to adequately solve the nonlinear problem.

When a cache is initialized for a `NonlinearProblem`, the candidates are, in order:
`TrustRegion` with its default settings, `TrustRegion` with the `Bastin` radius update
scheme, `NewtonRaphson` with a `BackTracking` line search, and `TrustRegion` with the `Fan`
radius update scheme.

If this algorithm fails, then "most" good ways of solving the problem fail as well, and the
model may need to be reformulated (either there is an issue with the model, or more
precision / a more stable linear solver choice is required).

### Keyword Arguments

  - `autodiff`: selects the backend used for the Jacobian. Ignored when an analytical
    Jacobian is supplied, since that one is used instead. Defaults to `AutoForwardDiff()`.
    Valid choices are types from ADTypes.jl. It is collected, together with any other extra
    keyword, into `adkwargs` and handed on to the underlying Newton methods.
  - `concrete_jac`: whether to build a concrete Jacobian. With a Krylov-subspace linear
    solver the Jacobian is not constructed; Jacobian-vector products `J*v` are computed
    directly via forward-mode automatic differentiation or finite differencing instead.
    Pass `concrete_jac = true` to force construction of the Jacobian when it is still
    needed, for example for a preconditioner.
  - `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used
    for the linear solves inside the Newton methods. Defaults to `nothing`, meaning the
    LinearSolve.jl default algorithm choice. See the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/) for the
    available algorithms.
  - `precs`: the preconditioners for the linear solver. Defaults to no preconditioning. See
    the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/) for how to
    specify preconditioners.
"""
@concrete struct RobustMultiNewton{CJ} <: AbstractNewtonAlgorithm{CJ, Nothing}
    adkwargs
    linsolve
    precs
end

# Keyword constructor: `concrete_jac` is lifted into the `CJ` type parameter (a `Val` is
# unwrapped first); everything else is stored as-is for use when the cache is built.
function RobustMultiNewton(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, adkwargs...)
    CJ = _unwrap_val(concrete_jac)
    return RobustMultiNewton{CJ}(adkwargs, linsolve, precs)
end
44+
45+
# Cache for `RobustMultiNewton`.
#   - `caches`:  tuple with one initialized sub-solver cache per candidate algorithm, in
#                the order in which they are tried
#   - `alg`:     the polyalgorithm this cache was built from
#   - `current`: index into `caches` of the sub-solver currently in use (starts at 1)
@concrete mutable struct RobustMultiNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
    caches
    alg
    current::Int
end

# Build the polyalgorithm cache: every candidate algorithm is initialized up front on the
# same problem with the same positional and keyword arguments.
#
# The `concrete_jac` choice is carried by the `CJ` type parameter of `alg`; it is forwarded
# to every sub-algorithm here so that the documented option is not silently dropped.
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::RobustMultiNewton{CJ},
        args...; kwargs...) where {uType, iip, CJ}
    adkwargs = alg.adkwargs
    linsolve = alg.linsolve
    precs = alg.precs
    concrete_jac = CJ

    # Order matters: this is the order in which the solvers are attempted.
    algs = (
        TrustRegion(; concrete_jac, linsolve, precs, adkwargs...),
        TrustRegion(; concrete_jac, linsolve, precs,
            radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...),
        NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
            adkwargs...),
        TrustRegion(; concrete_jac, linsolve, precs,
            radius_update_scheme = RadiusUpdateSchemes.Fan, adkwargs...),
    )

    # `map` over a tuple yields a tuple, so `caches` keeps one concrete type per entry.
    caches = map(a -> SciMLBase.__init(prob, a, args...; kwargs...), algs)

    return RobustMultiNewtonCache{iip}(caches, alg, 1)
end
66+
67+
"""
    FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, adkwargs...)

Polyalgorithm that trades off speed against robustness. Cheaper, less robust methods are
attempted first; more robust techniques are only used when the faster ones fail.

The direct `solve` path tries `Klement`, `Broyden`, `NewtonRaphson` and `TrustRegion`, in
that order. The cache built by `init` holds `NewtonRaphson`, `NewtonRaphson` with a
`BackTracking` line search, `TrustRegion`, and `TrustRegion` with the `Bastin` radius update
scheme.

### Keyword Arguments

  - `autodiff`: selects the backend used for the Jacobian. Ignored when an analytical
    Jacobian is supplied, since that one is used instead. Defaults to `AutoForwardDiff()`.
    Valid choices are types from ADTypes.jl. It is collected, together with any other extra
    keyword, into `adkwargs` and handed on to the underlying Newton methods.
  - `concrete_jac`: whether to build a concrete Jacobian. With a Krylov-subspace linear
    solver the Jacobian is not constructed; Jacobian-vector products `J*v` are computed
    directly via forward-mode automatic differentiation or finite differencing instead.
    Pass `concrete_jac = true` to force construction of the Jacobian when it is still
    needed, for example for a preconditioner.
  - `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used
    for the linear solves inside the Newton methods. Defaults to `nothing`, meaning the
    LinearSolve.jl default algorithm choice. See the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/) for the
    available algorithms.
  - `precs`: the preconditioners for the linear solver. Defaults to no preconditioning. See
    the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/) for how to
    specify preconditioners.
"""
@concrete struct FastShortcutNonlinearPolyalg{CJ} <: AbstractNewtonAlgorithm{CJ, Nothing}
    adkwargs
    linsolve
    precs
end

# Keyword constructor: `concrete_jac` is lifted into the `CJ` type parameter (a `Val` is
# unwrapped first); everything else is stored as-is for use by `init` / `solve`.
function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, adkwargs...)
    CJ = _unwrap_val(concrete_jac)
    return FastShortcutNonlinearPolyalg{CJ}(adkwargs, linsolve, precs)
end
105+
106+
# Cache for `FastShortcutNonlinearPolyalg`.
#   - `caches`:  tuple with one initialized sub-solver cache per candidate algorithm, in
#                the order in which they are tried
#   - `alg`:     the polyalgorithm this cache was built from
#   - `current`: index into `caches` of the sub-solver currently in use (starts at 1)
@concrete mutable struct FastShortcutNonlinearPolyalgCache{iip} <:
                         AbstractNonlinearSolveCache{iip}
    caches
    alg
    current::Int
end

# NOTE(review): this keyword constructor looks like copy-paste residue from the algorithm
# constructor. It fills the cache fields `(caches, alg, current::Int)` with
# `(adkwargs, linsolve, precs)` and uses `concrete_jac` as the `iip` parameter, so it is
# not expected to produce a usable cache. Nothing in this file calls it; it is kept
# unchanged only so the set of defined methods stays the same. Consider removing it.
function FastShortcutNonlinearPolyalgCache(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, adkwargs...)
    return FastShortcutNonlinearPolyalgCache{_unwrap_val(concrete_jac)}(adkwargs,
        linsolve, precs)
end

# Build the polyalgorithm cache: every candidate algorithm is initialized up front on the
# same problem with the same positional and keyword arguments.
#
# The `concrete_jac` choice is carried by the `CJ` type parameter of `alg`; it is forwarded
# to every sub-algorithm here so that the documented option is not silently dropped.
#
# `Klement()` and `Broyden()` are tried first by the direct `__solve` path for this
# algorithm but are intentionally not part of the cache-based path yet.
function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
        alg::FastShortcutNonlinearPolyalg{CJ}, args...;
        kwargs...) where {uType, iip, CJ}
    adkwargs = alg.adkwargs
    linsolve = alg.linsolve
    precs = alg.precs
    concrete_jac = CJ

    # Order matters: cheapest first, most robust last.
    algs = (
        NewtonRaphson(; concrete_jac, linsolve, precs, adkwargs...),
        NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
            adkwargs...),
        TrustRegion(; concrete_jac, linsolve, precs, adkwargs...),
        TrustRegion(; concrete_jac, linsolve, precs,
            radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...),
    )

    # `map` over a tuple yields a tuple, so `caches` keeps one concrete type per entry.
    caches = map(a -> SciMLBase.__init(prob, a, args...; kwargs...), algs)

    return FastShortcutNonlinearPolyalgCache{iip}(caches, alg, 1)
end
135+
136+
# Direct (cache-free) solve for `FastShortcutNonlinearPolyalg`.
#
# Tries `Klement`, `Broyden`, `NewtonRaphson` and `TrustRegion` in that order, each on the
# original problem. The first solution with a successful return code is returned. If none
# succeeds, the attempt whose residual is smallest under `DEFAULT_NORM` is returned with its
# own return code. In every case the result is re-wrapped so that it reports the
# polyalgorithm `alg` as the algorithm that produced it.
function SciMLBase.__solve(prob::NonlinearProblem{uType, iip},
        alg::FastShortcutNonlinearPolyalg, args...; kwargs...) where {uType, iip}
    linsolve, precs, adkwargs = alg.linsolve, alg.precs, alg.adkwargs

    # Re-package a sub-solver solution under the polyalgorithm's name.
    rewrap(sol) = SciMLBase.build_solution(prob, alg, sol.u, sol.resid;
        sol.retcode, sol.stats)

    # Each later solver only runs when all earlier ones have failed.
    sol1 = SciMLBase.__solve(prob, Klement(), args...; kwargs...)
    SciMLBase.successful_retcode(sol1) && return rewrap(sol1)

    sol2 = SciMLBase.__solve(prob, Broyden(), args...; kwargs...)
    SciMLBase.successful_retcode(sol2) && return rewrap(sol2)

    sol3 = SciMLBase.__solve(prob, NewtonRaphson(; linsolve, precs, adkwargs...), args...;
        kwargs...)
    SciMLBase.successful_retcode(sol3) && return rewrap(sol3)

    sol4 = SciMLBase.__solve(prob, TrustRegion(; linsolve, precs, adkwargs...), args...;
        kwargs...)
    SciMLBase.successful_retcode(sol4) && return rewrap(sol4)

    # Nothing converged: fall back to the attempt with the smallest residual norm.
    attempts = (sol1, sol2, sol3, sol4)
    _, idx = findmin(DEFAULT_NORM, map(s -> s.resid, attempts))
    1 <= idx <= 4 || error("Unreachable reached, 박정석")

    return rewrap(attempts[idx])
end
187+
188+
## General shared polyalg functions
189+
190+
# Advance the currently selected sub-solver of a polyalgorithm cache by exactly one step.
#
# `cache.current` selects which entry of `cache.caches` is stepped; an index outside 1:4
# raises an error. Returns `nothing`.
#
# The previous version wrapped the dispatch below in `while true` with no exit, so a single
# call never returned. A step function must perform one step and hand control back.
function perform_step!(cache::Union{RobustMultiNewtonCache,
        FastShortcutNonlinearPolyalgCache})
    current = cache.current

    # Literal tuple indices (rather than `cache.caches[current]`) let each branch call
    # `perform_step!` on an element whose type is known from the tuple type.
    if current == 1
        perform_step!(cache.caches[1])
    elseif current == 2
        perform_step!(cache.caches[2])
    elseif current == 3
        perform_step!(cache.caches[3])
    elseif current == 4
        perform_step!(cache.caches[4])
    else
        error("Current choices shouldn't get here!")
    end

    return nothing
end
209+
210+
# Run the polyalgorithm stored in `cache` to completion.
#
# Starting from `cache.current`, each sub-solver cache is solved in turn. The first one that
# finishes with a successful return code wins and its solution is returned, re-wrapped so
# that it reports the polyalgorithm (`cache.alg`) as the producing algorithm. If every
# sub-solver fails, the iterate with the smallest residual (measured with the first
# sub-cache's `internalnorm`) is returned with `ReturnCode.MaxIters`.
#
# Fixes relative to the previous implementation:
#   - the local copy of `cache.current` was never refreshed, so only sub-solver 1 ever ran;
#   - the loop stopped as soon as *any* sub-cache had terminated, so no fallback happened;
#   - success was reported whenever `nsteps != maxiters`, even for a failed sub-solver;
#   - the all-failed branch referenced an undefined variable `fu`.
#
# NOTE(review): this relies on every sub-cache (each created via `SciMLBase.__init`)
# supporting `SciMLBase.solve!` and returning a solution with `u`, `resid`, `retcode` and
# `stats`, as the solutions used by `__solve` in this file do -- confirm.
function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache,
        FastShortcutNonlinearPolyalgCache})
    subcaches = cache.caches
    nalgs = length(subcaches)
    prob = subcaches[1].prob

    while cache.current <= nalgs
        # Indexing with a runtime value is not inferrable, but the `solve!` call below acts
        # as a function barrier and happens at most `nalgs` times.
        sol = SciMLBase.solve!(subcaches[cache.current])
        if SciMLBase.successful_retcode(sol)
            return SciMLBase.build_solution(prob, cache.alg, sol.u, sol.resid;
                sol.retcode, sol.stats)
        end
        # This sub-solver failed: move on to the next, more robust candidate.
        cache.current += 1
    end

    # Every candidate failed. Report the best point found by any of them.
    fus = map(get_fu, subcaches)
    _, idx = findmin(subcaches[1].internalnorm, fus)
    best = subcaches[idx]

    return SciMLBase.build_solution(prob, cache.alg, best.u, fus[idx];
        retcode = ReturnCode.MaxIters, stats = best.stats)
end
304+
305+
# Re-initialize a polyalgorithm cache so it can be reused for another solve.
#
# All positional and keyword arguments are forwarded unchanged to `SciMLBase.reinit!` of
# every sub-solver cache. The algorithm selector is also rewound to the first candidate;
# without that, a cache that had already fallen through to a later sub-solver (or had
# exhausted all of them) would resume from there instead of starting over.
# Returns `nothing`.
function SciMLBase.reinit!(cache::Union{RobustMultiNewtonCache,
        FastShortcutNonlinearPolyalgCache}, args...; kwargs...)
    foreach(subcache -> SciMLBase.reinit!(subcache, args...; kwargs...), cache.caches)
    cache.current = 1
    return nothing
end
310+
311+
## Defaults
312+
313+
# Default algorithm for `init`: when no algorithm is given (`alg === nothing`), build the
# cache for a default-constructed `FastShortcutNonlinearPolyalg`, forwarding all remaining
# positional and keyword arguments untouched.
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::Nothing, args...;
        kwargs...) where {uType, iip}
    default_alg = FastShortcutNonlinearPolyalg()
    return SciMLBase.__init(prob, default_alg, args...; kwargs...)
end
318+
319+
# Default algorithm for `solve`: when no algorithm is given (`alg === nothing`), solve with
# a default-constructed `FastShortcutNonlinearPolyalg`, forwarding all remaining positional
# and keyword arguments untouched.
function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::Nothing, args...;
        kwargs...) where {uType, iip}
    default_alg = FastShortcutNonlinearPolyalg()
    return SciMLBase.__solve(prob, default_alg, args...; kwargs...)
end

test/polyalgs.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using NonlinearSolve, Test

# Componentwise u^2 - 2 = 0; starting from 1.0 the expected root is sqrt(2) in each entry.
f(u, p) = u .* u .- 2
u0 = [1.0, 1.0]
probN = NonlinearProblem(f, u0)
expected = fill(sqrt(2.0), length(u0))

# Each solve is still timed, but is now also checked: previously this file asserted
# nothing, so a solver that failed to converge would still have passed the test set.

# Default algorithm (no `alg` argument).
@time solver = solve(probN, abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
@test isapprox(solver.u, expected; atol = 1e-6)

# Robustness-focused polyalgorithm.
@time solver = solve(probN, RobustMultiNewton(), abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
@test isapprox(solver.u, expected; atol = 1e-6)

# Speed/robustness-balanced polyalgorithm, selected explicitly.
@time solver = solve(probN, FastShortcutNonlinearPolyalg(), abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
@test isapprox(solver.u, expected; atol = 1e-6)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
if GROUP == "All" || GROUP == "Core"
1616
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
18-
18+
@time @safetestset "Polyalgs" include("polyalgs.jl")
1919
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
2020
end
2121

0 commit comments

Comments
 (0)