Skip to content

Commit 8f27d1f

Browse files
committed
feat: SimpleTrustRegion implementation
1 parent ed995cd commit 8f27d1f

File tree

3 files changed

+217
-25
lines changed

3 files changed

+217
-25
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SimpleNonlinearSolve
33
using CommonSolve: CommonSolve, solve
44
using ConcreteStructs: @concrete
55
using FastClosures: @closure
6+
using LinearAlgebra: dot
67
using MaybeInplace: @bb
78
using PrecompileTools: @compile_workload, @setup_workload
89
using Reexport: @reexport
@@ -17,7 +18,8 @@ using FiniteDiff: FiniteDiff
1718
using ForwardDiff: ForwardDiff
1819

1920
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
20-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance
21+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance,
22+
L2_NORM
2123

2224
const DI = DifferentiationInterface
2325

@@ -78,7 +80,8 @@ function solve_adjoint_internal end
7880

7981
algs = [
8082
SimpleKlement(),
81-
SimpleNewtonRaphson()
83+
SimpleNewtonRaphson(),
84+
SimpleTrustRegion()
8285
]
8386
algs_no_iip = []
8487

@@ -98,6 +101,6 @@ export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
98101
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
99102

100103
export SimpleKlement
101-
export SimpleGaussNewton, SimpleNewtonRaphson
104+
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
102105

103106
end
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,212 @@
11

2+
# """
3+
# SimpleTrustRegion(; autodiff = nothing, max_trust_radius = 0.0,
4+
# initial_trust_radius = 0.0, step_threshold = 0.0001,
5+
# shrink_threshold = nothing, expand_threshold = nothing,
6+
# shrink_factor = nothing, expand_factor = 2.0, max_shrink_times::Int = 32,
7+
# nlsolve_update_rule = Val(false))
8+
9+
# A low-overhead implementation of a trust-region solver. This method is non-allocating on
10+
# scalar and static array problems.
11+
12+
# ### Keyword Arguments
13+
14+
# - `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
15+
# automatic backend selection). Valid choices include jacobian backends from
16+
# `DifferentiationInterface.jl`.
17+
# - `max_trust_radius`: the maximum radius of the trust region. Defaults to
18+
# `max(norm(f(u0)), maximum(u0) - minimum(u0))`.
19+
# - `initial_trust_radius`: the initial trust region radius. Defaults to
20+
# `max_trust_radius / 11`.
21+
# - `step_threshold`: the threshold for taking a step. In every iteration, the threshold is
22+
# compared with a value `r`, which is the actual reduction in the objective function divided
23+
# by the predicted reduction. If `step_threshold > r` the model is not a good approximation,
24+
# and the step is rejected. Defaults to `0.0001`. For more details, see
25+
# [Rahpeymaii, F.](https://link.springer.com/article/10.1007/s40096-020-00339-4)
26+
# - `shrink_threshold`: the threshold for shrinking the trust region radius. In every
27+
# iteration, the threshold is compared with a value `r` which is the actual reduction in the
28+
# objective function divided by the predicted reduction. If `shrink_threshold > r` the trust
29+
# region radius is shrunk by `shrink_factor`. Defaults to `0.25`. For more details, see
30+
# [Rahpeymaii, F.](https://link.springer.com/article/10.1007/s40096-020-00339-4)
31+
# - `expand_threshold`: the threshold for expanding the trust region radius. If a step is
32+
# taken, i.e. `step_threshold < r` (with `r` defined in `shrink_threshold`), a check is also
33+
# made to see if `expand_threshold < r`. If that is true, the trust region radius is
34+
# expanded by `expand_factor`. Defaults to `0.75`.
35+
# - `shrink_factor`: the factor to shrink the trust region radius with if
36+
# `shrink_threshold > r` (with `r` defined in `shrink_threshold`). Defaults to `0.25`.
37+
# - `expand_factor`: the factor to expand the trust region radius with if
38+
# `expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
39+
# - `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
40+
# row. If `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
41+
# - `nlsolve_update_rule`: If set to `Val(true)`, updates the trust region radius using the
42+
# update rule from NLSolve.jl. Defaults to `Val(false)`. If set to `Val(true)`, few of the
43+
# radius update parameters -- `shrink_threshold = 0.05`, `expand_threshold = 0.9`, and
44+
# `shrink_factor = 0.5` -- have different defaults.
45+
# """
46+
# Configuration for the simple trust-region solver. Fields left at `nothing` (or at zero for
# the two radii) are resolved inside `SciMLBase.__solve` when the solve starts.
@kwdef @concrete struct SimpleTrustRegion <: AbstractSimpleNonlinearSolveAlgorithm
    # Jacobian backend. Used as-is when the problem supplies its own Jacobian; otherwise it
    # is handed to `NonlinearSolveBase.select_jacobian_autodiff`.
    autodiff = nothing
    # Upper bound on the trust-region radius; a zero value asks the solver to derive one
    # from the initial residual and the spread of `u0`.
    max_trust_radius = 0.0
    # Starting radius; a zero value asks the solver to derive one.
    initial_trust_radius = 0.0
    # A trial step is accepted only when the actual/predicted reduction ratio reaches this.
    step_threshold = 0.0001
    # Ratio below which the radius is shrunk. `nothing` resolves to 0.25, or to 0.05 when
    # `nlsolve_update_rule` is `Val(true)`.
    shrink_threshold = nothing
    # Ratio above which the radius is expanded. `nothing` resolves to 0.75, or to 0.9 when
    # `nlsolve_update_rule` is `Val(true)`.
    expand_threshold = nothing
    # Multiplier applied when shrinking. `nothing` resolves to 0.25, or to 0.5 when
    # `nlsolve_update_rule` is `Val(true)`.
    shrink_factor = nothing
    # Multiplier applied when expanding.
    expand_factor = 2.0
    # Number of consecutive shrinks tolerated before the solver gives up.
    max_shrink_times::Int = 32
    # `Val(true)` switches the radius update to the NLSolve.jl-style rule.
    nlsolve_update_rule = Val(false)
end
58+
59+
# Solve `prob` with a dogleg trust-region iteration configured by `alg`.
#
# Keyword arguments used here:
#   - `abstol`, `reltol`, `termination_condition`: forwarded to
#     `NonlinearSolveBase.init_termination_cache`.
#   - `maxiters`: upper bound on the number of trust-region iterations.
#   - `alias_u0`: forwarded to `Utils.maybe_unaliased` together with `prob.u0`.
#
# Returns the result of `SciMLBase.build_solution`. The retcode is the one reported by
# `Utils.check_termination` when it signals success, `ReturnCode.ShrinkThresholdExceeded`
# when the radius was shrunk more than `alg.max_shrink_times` times in a row, and
# `ReturnCode.MaxIters` when the iteration budget runs out.
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion,
        args...; abstol = nothing, reltol = nothing, maxiters = 1000,
        alias_u0 = false, termination_condition = nothing, kwargs...)
    x = Utils.maybe_unaliased(prob.u0, alias_u0)
    T = eltype(x)
    Δₘₐₓ = T(alg.max_trust_radius)
    Δ = T(alg.initial_trust_radius)
    η₁ = T(alg.step_threshold)

    # Thresholds/factors left at `nothing` get a default that depends on the update rule.
    if alg.shrink_threshold === nothing
        η₂ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.05, 0.25))
    else
        η₂ = T(alg.shrink_threshold)
    end

    if alg.expand_threshold === nothing
        η₃ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.9, 0.75))
    else
        η₃ = T(alg.expand_threshold)
    end

    if alg.shrink_factor === nothing
        t₁ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.5, 0.25))
    else
        t₁ = T(alg.shrink_factor)
    end

    t₂ = T(alg.expand_factor)
    max_shrink_times = alg.max_shrink_times

    autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
               NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)

    fx = Utils.get_fx(prob, x)
    fx = Utils.eval_f(prob, fx, x)
    norm_fx = L2_NORM(fx)

    @bb xo = copy(x)
    # A residual scratch buffer is only allocated for in-place problems without a
    # user-supplied Jacobian.
    fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
               nothing
    jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
    J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

    abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
        prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

    # Set default trust region radius if not specified by user.
    iszero(Δₘₐₓ) && (Δₘₐₓ = max(L2_NORM(fx), maximum(x) - minimum(x)))
    if iszero(Δ)
        if SciMLBase._unwrap_val(alg.nlsolve_update_rule)
            norm_x = L2_NORM(x)
            Δ = T(ifelse(norm_x > 0, norm_x, 1))
        else
            Δ = T(Δₘₐₓ / 11)
        end
    end

    # Divide by `T(2)` rather than multiplying by the Float64 literal `0.5`, so that `fₖ`
    # has the same type as `fₖ₊₁` below (e.g. stays Float32 for Float32 problems).
    fₖ = norm_fx^2 / T(2)
    H = transpose(J) * J
    g = Utils.restructure(x, J' * Utils.safe_vec(fx))
    shrink_counter = 0

    @bb δsd = copy(x)
    @bb δN_δsd = copy(x)
    @bb δN = copy(x)
    @bb Hδ = copy(x)
    dogleg_cache = (; δsd, δN_δsd, δN)

    for _ in 1:maxiters
        # Solve the trust region subproblem.
        δ = dogleg_method!!(dogleg_cache, J, fx, g, Δ)
        @bb @. x = xo + δ

        fx = Utils.eval_f(prob, fx, x)

        fₖ₊₁ = L2_NORM(fx)^2 / T(2)

        # Compute the ratio of the actual to predicted reduction.
        @bb Hδ = H × vec(δ)
        r = (fₖ₊₁ - fₖ) / (dot(δ, g) + (dot(δ, Hδ) / T(2)))

        # Update the trust region radius: shrink when the ratio falls below `η₂`, and give
        # up after too many consecutive shrinks.
        if r ≥ η₂
            shrink_counter = 0
        else
            Δ = t₁ * Δ
            shrink_counter += 1
            shrink_counter > max_shrink_times && return SciMLBase.build_solution(
                prob, alg, x, fx; retcode = ReturnCode.ShrinkThresholdExceeded)
        end

        # The trial point is accepted only when the ratio reaches the step threshold.
        if r ≥ η₁
            # Termination Checks
            solved, retcode, fx_sol, x_sol = Utils.check_termination(
                tc_cache, fx, x, xo, prob)
            solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)

            # Take the step.
            @bb copyto!(xo, x)

            J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
            fx = Utils.eval_f(prob, fx, x)

            # Update the trust region radius.
            if !SciMLBase._unwrap_val(alg.nlsolve_update_rule) && r > η₃
                Δ = min(t₂ * Δ, Δₘₐₓ)
            end
            fₖ = fₖ₊₁

            @bb H = transpose(J) × J
            @bb g = transpose(J) × vec(fx)
        end

        # NLSolve.jl-style rule: the new radius is derived from the length of the step.
        if SciMLBase._unwrap_val(alg.nlsolve_update_rule)
            if r > η₃
                Δ = t₂ * L2_NORM(δ)
            elseif r > 0.5
                Δ = max(Δ, t₂ * L2_NORM(δ))
            end
        end
    end

    return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end
183+
184+
# Compute a dogleg step for the trust-region subproblem with radius `Δ`.
#
# `cache` must provide the work buffers `δsd`, `δN_δsd` and `δN`. `J` is the Jacobian, `f`
# the residual and `g` the gradient vector supplied by the caller. The step is returned
# rather than written to a fixed output, so callers must use the return value.
# NOTE(review): `@bb` presumably falls back to out-of-place updates for immutable arrays,
# which is why the result cannot be read back from `cache` -- confirm against MaybeInplace.
function dogleg_method!!(cache, J, f, g, Δ)
    (; δsd, δN_δsd, δN) = cache

    # Compute the Newton step
    @bb δN .= Utils.restructure(δN, J \ Utils.safe_vec(f))
    @bb δN .*= -1
    # Test if the full step is within the trust region
    (L2_NORM(δN) ≤ Δ) && return δN

    # Calculate Cauchy point, optimum along the steepest descent direction
    @bb δsd .= g
    @bb @. δsd *= -1
    norm_δsd = L2_NORM(δsd)

    # If the steepest-descent step already reaches the boundary, rescale it to length `Δ`.
    if (norm_δsd ≥ Δ)
        @bb @. δsd *= Δ / norm_δsd
        return δsd
    end

    # Find the intersection point on the boundary: the positive root `tau` of
    # ‖δsd + tau * (δN - δsd)‖² = Δ².
    @bb @. δN_δsd = δN - δsd
    dot_δN_δsd = dot(δN_δsd, δN_δsd)
    dot_δsd_δN_δsd = dot(δsd, δN_δsd)
    dot_δsd = dot(δsd, δsd)
    fact = dot_δsd_δN_δsd^2 - dot_δN_δsd * (dot_δsd - Δ^2)
    tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
    @bb @. δsd += tau * δN_δsd
    return δsd
end

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,6 @@ function maybe_unaliased(x::T, alias::Bool) where {T <: AbstractArray}
2828
return copy(x)
2929
end
3030

31-
function get_concrete_autodiff(_, ad::AbstractADType)
32-
DI.check_available(ad) && return ad
33-
error("AD Backend $(ad) is not available. This could be because you haven't loaded the \
34-
actual backend (See [Differentiation Interface Docs](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/) \
35-
for more details) or the backend might not be supported by DifferentiationInterface.jl.")
36-
end
37-
function get_concrete_autodiff(
38-
prob, ad::Union{AutoForwardDiff{nothing}, AutoPolyesterForwardDiff{nothing}})
39-
return get_concrete_autodiff(prob,
40-
ArrayInterface.parameterless_type(ad)(;
41-
chunksize = pickchunksize(length(prob.u0)), ad.tag))
42-
end
43-
function get_concrete_autodiff(prob, ::Nothing)
44-
if can_dual(eltype(prob.u0)) && DI.check_available(AutoForwardDiff())
45-
return AutoForwardDiff(; chunksize = pickchunksize(length(prob.u0)))
46-
end
47-
DI.check_available(AutoFiniteDiff()) && return AutoFiniteDiff()
48-
error("Default AD backends are not available. Please load either FiniteDiff or \
49-
ForwardDiff for default AD selection to work. Else provide a specific AD \
50-
backend (instead of `nothing`) to the solver.")
51-
end
52-
5331
# NOTE: This doesn't initialize the `f(x)` but just returns a buffer of the same size
5432
function get_fx(prob::NonlinearLeastSquaresProblem, x)
5533
if SciMLBase.isinplace(prob) && prob.f.resid_prototype === nothing

0 commit comments

Comments
 (0)