Skip to content

Commit 61e97a8

Browse files
committed
refactor: Remove unnecessary snippet
1 parent 70e8eff commit 61e97a8

File tree

4 files changed

+343
-546
lines changed

4 files changed

+343
-546
lines changed

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,78 @@ function SciMLBase.__init(
121121
)
122122
end
123123

124+
@generated function CommonSolve.solve!(cache::NonlinearSolvePolyAlgorithmCache{Val{N}}) where {N}
125+
calls = [quote
126+
1 cache.current $(N) || error("Current choices shouldn't get here!")
127+
end]
128+
129+
cache_syms = [gensym("cache") for i in 1:N]
130+
sol_syms = [gensym("sol") for i in 1:N]
131+
u_result_syms = [gensym("u_result") for i in 1:N]
132+
133+
for i in 1:N
134+
push!(calls,
135+
quote
136+
$(cache_syms[i]) = cache.caches[$(i)]
137+
if $(i) == cache.current
138+
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
139+
$(sol_syms[i]) = CommonSolve.solve!($(cache_syms[i]))
140+
if SciMLBase.successful_retcode($(sol_syms[i]))
141+
stats = $(sol_syms[i]).stats
142+
if cache.alias_u0
143+
copyto!(cache.u0, $(sol_syms[i]).u)
144+
$(u_result_syms[i]) = cache.u0
145+
else
146+
$(u_result_syms[i]) = $(sol_syms[i]).u
147+
end
148+
fu = NonlinearSolveBase.get_fu($(cache_syms[i]))
149+
return build_solution_less_specialize(
150+
cache.prob, cache.alg, $(u_result_syms[i]), fu;
151+
retcode = $(sol_syms[i]).retcode, stats,
152+
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace
153+
)
154+
elseif cache.alias_u0
155+
# For safety we need to maintain a copy of the solution
156+
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
157+
end
158+
cache.current = $(i + 1)
159+
end
160+
end)
161+
end
162+
163+
resids = map(Base.Fix2(Symbol, :resid), cache_syms)
164+
for (sym, resid) in zip(cache_syms, resids)
165+
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
166+
end
167+
push!(calls, quote
168+
fus = tuple($(Tuple(resids)...))
169+
minfu, idx = findmin_caches(cache.prob, fus)
170+
end)
171+
for i in 1:N
172+
push!(calls,
173+
quote
174+
if idx == $(i)
175+
u = cache.alias_u0 ? $(u_result_syms[i]) :
176+
NonlinearSolveBase.get_u(cache.caches[$(i)])
177+
end
178+
end)
179+
end
180+
push!(calls,
181+
quote
182+
retcode = cache.caches[idx].retcode
183+
if cache.alias_u0
184+
copyto!(cache.u0, u)
185+
u = cache.u0
186+
end
187+
return build_solution_less_specialize(
188+
cache.prob, cache.alg, u, fus[idx];
189+
retcode, cache.stats, cache.caches[idx].trace
190+
)
191+
end)
192+
193+
return Expr(:block, calls...)
194+
end
195+
124196
@generated function InternalAPI.step!(
125197
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
126198
) where {N}
@@ -160,6 +232,92 @@ end
160232
return Expr(:block, calls...)
161233
end
162234

235+
@generated function SciMLBase.__solve(
236+
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
237+
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
238+
) where {N}
239+
sol_syms = [gensym("sol") for _ in 1:N]
240+
prob_syms = [gensym("prob") for _ in 1:N]
241+
u_result_syms = [gensym("u_result") for _ in 1:N]
242+
calls = [quote
243+
current = alg.start_index
244+
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
245+
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
246+
immutable (checked using `ArrayInterface.ismutable`)."
247+
alias_u0 = false # If immutable don't care about aliasing
248+
end
249+
u0 = prob.u0
250+
u0_aliased = alias_u0 ? zero(u0) : u0
251+
end]
252+
for i in 1:N
253+
cur_sol = sol_syms[i]
254+
push!(calls,
255+
quote
256+
if current == $(i)
257+
if alias_u0
258+
copyto!(u0_aliased, u0)
259+
$(prob_syms[i]) = SciMLBase.remake(prob; u0 = u0_aliased)
260+
else
261+
$(prob_syms[i]) = prob
262+
end
263+
$(cur_sol) = SciMLBase.__solve(
264+
$(prob_syms[i]), alg.algs[$(i)], args...;
265+
stats, alias_u0, verbose, kwargs...
266+
)
267+
if SciMLBase.successful_retcode($(cur_sol))
268+
if alias_u0
269+
copyto!(u0, $(cur_sol).u)
270+
$(u_result_syms[i]) = u0
271+
else
272+
$(u_result_syms[i]) = $(cur_sol).u
273+
end
274+
return build_solution_less_specialize(
275+
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
276+
$(cur_sol).retcode, $(cur_sol).stats,
277+
$(cur_sol).trace, original = $(cur_sol)
278+
)
279+
elseif alias_u0
280+
# For safety we need to maintain a copy of the solution
281+
$(u_result_syms[i]) = copy($(cur_sol).u)
282+
end
283+
current = $(i + 1)
284+
end
285+
end)
286+
end
287+
288+
resids = map(Base.Fix2(Symbol, :resid), sol_syms)
289+
for (sym, resid) in zip(sol_syms, resids)
290+
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
291+
end
292+
293+
push!(calls, quote
294+
resids = tuple($(Tuple(resids)...))
295+
minfu, idx = findmin_resids(prob, resids)
296+
end)
297+
298+
for i in 1:N
299+
push!(calls,
300+
quote
301+
if idx == $(i)
302+
if alias_u0
303+
copyto!(u0, $(u_result_syms[i]))
304+
$(u_result_syms[i]) = u0
305+
else
306+
$(u_result_syms[i]) = $(sol_syms[i]).u
307+
end
308+
return build_solution_less_specialize(
309+
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
310+
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
311+
$(sol_syms[i]).trace, original = $(sol_syms[i])
312+
)
313+
end
314+
end)
315+
end
316+
push!(calls, :(error("Current choices shouldn't get here!")))
317+
318+
return Expr(:block, calls...)
319+
end
320+
163321
# Original is often determined on runtime information especially for PolyAlgorithms so it
164322
# is best to never specialize on that
165323
function build_solution_less_specialize(

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve
4747

4848
const SII = SymbolicIndexingInterface
4949

50-
include("polyalg.jl")
50+
include("poly_algs.jl")
5151
include("extension_algs.jl")
5252

5353
include("default.jl")

src/poly_algs.jl

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""
2+
RobustMultiNewton(
3+
::Type{T} = Float64;
4+
concrete_jac = nothing,
5+
linsolve = nothing,
6+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
7+
)
8+
9+
A polyalgorithm focused on robustness. It uses a mixture of Newton methods with different
10+
globalizing techniques (trust region updates, line searches, etc.) in order to find a
11+
method that is able to adequately solve the minimization problem.
12+
13+
Basically, if this algorithm fails, then "most" good ways of solving your problem fail and
14+
you may need to think about reformulating the model (either there is an issue with the model,
15+
or more precision / more stable linear solver choice is required).
16+
17+
### Arguments
18+
19+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
20+
are compatible with the problem type. Defaults to `Float64`.
21+
"""
22+
function RobustMultiNewton(
23+
::Type{T} = Float64;
24+
concrete_jac = nothing,
25+
linsolve = nothing,
26+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
27+
) where {T}
28+
common_kwargs = (; concrete_jac, linsolve, autodiff, vjp_autodiff, jvp_autodiff)
29+
if T <: Complex # Let's atleast have something here for complex numbers
30+
algs = (
31+
NewtonRaphson(; common_kwargs...),
32+
)
33+
else
34+
algs = (
35+
TrustRegion(; common_kwargs...),
36+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.Bastin),
37+
NewtonRaphson(; common_kwargs...),
38+
NewtonRaphson(; common_kwargs..., linesearch = BackTracking()),
39+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.NLsolve),
40+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.Fan)
41+
)
42+
end
43+
return NonlinearSolvePolyAlgorithm(algs)
44+
end
45+
46+
"""
47+
FastShortcutNonlinearPolyalg(
48+
::Type{T} = Float64;
49+
concrete_jac = nothing,
50+
linsolve = nothing,
51+
must_use_jacobian::Val = Val(false),
52+
prefer_simplenonlinearsolve::Val = Val(false),
53+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing,
54+
u0_len::Union{Int, Nothing} = nothing
55+
) where {T}
56+
57+
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
58+
for more performance and then tries more robust techniques if the faster ones fail.
59+
60+
### Arguments
61+
62+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
63+
are compatible with the problem type. Defaults to `Float64`.
64+
65+
### Keyword Arguments
66+
67+
- `u0_len`: The length of the initial guess. If this is `nothing`, then the length of the
68+
initial guess is not checked. If this is an integer and it is less than `25`, we use
69+
jacobian based methods.
70+
"""
71+
function FastShortcutNonlinearPolyalg(
72+
::Type{T} = Float64;
73+
concrete_jac = nothing,
74+
linsolve = nothing,
75+
must_use_jacobian::Val = Val(false),
76+
prefer_simplenonlinearsolve::Val = Val(false),
77+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing,
78+
u0_len::Union{Int, Nothing} = nothing
79+
) where {T}
80+
start_index = 1
81+
common_kwargs = (; concrete_jac, linsolve, autodiff, vjp_autodiff, jvp_autodiff)
82+
if must_use_jacobian isa Val{true}
83+
if T <: Complex
84+
algs = (NewtonRaphson(; common_kwargs...),)
85+
else
86+
algs = (
87+
NewtonRaphson(; common_kwargs...),
88+
NewtonRaphson(; common_kwargs..., linesearch = BackTracking()),
89+
TrustRegion(; common_kwargs...),
90+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.Bastin)
91+
)
92+
end
93+
else
94+
# SimpleNewtonRaphson and SimpleTrustRegion are not robust to singular Jacobians
95+
# and thus are not included in the polyalgorithm
96+
if prefer_simplenonlinearsolve isa Val{true}
97+
if T <: Complex
98+
algs = (
99+
SimpleBroyden(),
100+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
101+
SimpleKlement(),
102+
NewtonRaphson(; common_kwargs...)
103+
)
104+
else
105+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
106+
algs = (
107+
SimpleBroyden(),
108+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
109+
SimpleKlement(),
110+
NewtonRaphson(; common_kwargs...),
111+
NewtonRaphson(; common_kwargs..., linesearch = BackTracking()),
112+
TrustRegion(; common_kwargs...),
113+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.Bastin)
114+
)
115+
end
116+
else
117+
if T <: Complex
118+
algs = (
119+
Broyden(; autodiff),
120+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
121+
Klement(; linsolve, autodiff),
122+
NewtonRaphson(; common_kwargs...)
123+
)
124+
else
125+
# TODO: This number requires a bit rigorous testing
126+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
127+
algs = (
128+
Broyden(; autodiff),
129+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
130+
Klement(; linsolve, autodiff),
131+
NewtonRaphson(; common_kwargs...),
132+
NewtonRaphson(; common_kwargs..., linesearch = BackTracking()),
133+
TrustRegion(; common_kwargs...),
134+
TrustRegion(; common_kwargs..., radius_update_scheme = RUS.Bastin)
135+
)
136+
end
137+
end
138+
end
139+
return NonlinearSolvePolyAlgorithm(algs; start_index)
140+
end
141+
142+
"""
143+
FastShortcutNLLSPolyalg(
144+
::Type{T} = Float64;
145+
concrete_jac = nothing,
146+
linsolve = nothing,
147+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
148+
)
149+
150+
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
151+
for more performance and then tries more robust techniques if the faster ones fail.
152+
153+
### Arguments
154+
155+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
156+
are compatible with the problem type. Defaults to `Float64`.
157+
"""
158+
function FastShortcutNLLSPolyalg(
159+
::Type{T} = Float64;
160+
concrete_jac = nothing,
161+
linsolve = nothing,
162+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
163+
) where {T}
164+
common_kwargs = (; linsolve, autodiff, vjp_autodiff, jvp_autodiff)
165+
if T <: Complex
166+
algs = (
167+
GaussNewton(; common_kwargs..., concrete_jac),
168+
LevenbergMarquardt(; common_kwargs..., disable_geodesic = Val(true)),
169+
LevenbergMarquardt(; common_kwargs...)
170+
)
171+
else
172+
algs = (
173+
GaussNewton(; common_kwargs..., concrete_jac),
174+
LevenbergMarquardt(; common_kwargs..., disable_geodesic = Val(true)),
175+
TrustRegion(; common_kwargs..., concrete_jac),
176+
GaussNewton(; common_kwargs..., linesearch = BackTracking(), concrete_jac),
177+
TrustRegion(;
178+
common_kwargs..., radius_update_scheme = RUS.Bastin, concrete_jac
179+
),
180+
LevenbergMarquardt(; common_kwargs...)
181+
)
182+
end
183+
return NonlinearSolvePolyAlgorithm(algs)
184+
end

0 commit comments

Comments
 (0)