Skip to content

Commit f18fe15

Browse files
committed
Start cleaning up TrustRegion
1 parent 031639f commit f18fe15

File tree

3 files changed

+68
-125
lines changed

3 files changed

+68
-125
lines changed

src/jacobian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
209209
end
210210

211211
# Generic Handling of Krylov Methods for Normal Form Linear Solves
212+
# FIXME: Use MaybeInplace here for efficient matmuls
212213
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
213214
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
214215
end

src/trustRegion.jl

Lines changed: 53 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
`RadiusUpdateSchemes`
2+
RadiusUpdateSchemes
33
44
`RadiusUpdateSchemes` is the standard enum interface for different types of radius update schemes
55
implemented in the Trust Region method. These schemes specify how the radius of the so-called trust region
@@ -16,7 +16,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
1616
"""
1717
@enumx RadiusUpdateSchemes begin
1818
"""
19-
`RadiusUpdateSchemes.Simple`
19+
RadiusUpdateSchemes.Simple
2020
2121
The simple or conventional radius update scheme. This scheme is chosen by default
2222
and follows the conventional approach to update the trust region radius, i.e. if the
@@ -26,21 +26,21 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
2626
Simple
2727

2828
"""
29-
`RadiusUpdateSchemes.NLsolve`
29+
RadiusUpdateSchemes.NLsolve
3030
3131
The same updating scheme as in NLsolve's (https://github.com/JuliaNLSolvers/NLsolve.jl) trust region dogleg implementation.
3232
"""
3333
NLsolve
3434

3535
"""
36-
`RadiusUpdateSchemes.NocedalWright`
36+
RadiusUpdateSchemes.NocedalWright
3737
3838
Trust region updating scheme as in Nocedal and Wright [see Alg 11.5, page 291].
3939
"""
4040
NocedalWright
4141

4242
"""
43-
`RadiusUpdateSchemes.Hei`
43+
RadiusUpdateSchemes.Hei
4444
4545
This scheme is proposed by [Hei, L.] (https://www.jstor.org/stable/43693061). The trust region radius
4646
depends on the size (norm) of the current step size. The hypothesis is to let the radius converge to zero
@@ -50,7 +50,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
5050
Hei
5151

5252
"""
53-
`RadiusUpdateSchemes.Yuan`
53+
RadiusUpdateSchemes.Yuan
5454
5555
This scheme is proposed by [Yuan, Y.] (https://www.researchgate.net/publication/249011466_A_new_trust_region_algorithm_with_trust_region_radius_converging_to_zero).
5656
Similar to Hei's scheme, the trust region is updated in a way so that it converges to zero, however here,
@@ -60,7 +60,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
6060
Yuan
6161

6262
"""
63-
`RadiusUpdateSchemes.Bastin`
63+
RadiusUpdateSchemes.Bastin
6464
6565
This scheme is proposed by [Bastin, et al.] (https://www.researchgate.net/publication/225100660_A_retrospective_trust-region_method_for_unconstrained_optimization).
6666
The scheme is called a retrospective update scheme as it uses the model function at the current
@@ -71,7 +71,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
7171
Bastin
7272

7373
"""
74-
`RadiusUpdateSchemes.Fan`
74+
RadiusUpdateSchemes.Fan
7575
7676
This scheme is proposed by [Fan, J.] (https://link.springer.com/article/10.1007/s10589-005-3078-8). It is very much similar to
7777
Hei's and Yuan's schemes as it lets the trust region radius depend on the current size (norm) of the objective (merit)
@@ -170,7 +170,7 @@ function set_ad(alg::TrustRegion{CJ}, ad) where {CJ}
170170
end
171171

172172
function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
173-
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, #defaults to conventional radius update
173+
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple,
174174
max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,
175175
step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4,
176176
expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4,
@@ -233,6 +233,7 @@ end
233233
trace
234234
end
235235

236+
# TODO: add J_cache
236237
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...;
237238
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
238239
termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;),
@@ -244,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
244245
fu1 = evaluate_f(prob, u)
245246
fu_prev = zero(fu1)
246247

247-
loss = get_loss(fu1)
248+
loss = __get_trust_region_loss(fu1)
248249
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
249250
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
250251
g = _restructure(fu1, g)
@@ -350,92 +351,54 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
350351
p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
351352
end
352353

353-
function perform_step!(cache::TrustRegionCache{true})
354-
@unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache
354+
function perform_step!(cache::TrustRegionCache{iip}) where {iip}
355355
if cache.make_new_J
356-
jacobian!!(J, cache)
357-
__update_JᵀJ!(Val{true}(), cache, :H, J)
358-
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu))
356+
cache.J = jacobian!!(cache.J, cache)
357+
358+
__update_JᵀJ!(Val{iip}(), cache, :H, cache.J)
359+
__update_Jᵀf!(Val{iip}(), cache, :g, :H, cache.J, _vec(cache.fu))
359360
cache.stats.njacs += 1
360361

361362
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
362363
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
363-
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
364-
linu = _vec(u_gauss_newton), p = p, reltol = cache.abstol)
364+
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J,
365+
b = _vec(cache.fu), linu = _vec(cache.u_gauss_newton), p = cache.p,
366+
reltol = cache.abstol)
365367
cache.linsolve = linres.cache
366-
@. cache.u_gauss_newton = -1 * u_gauss_newton
367-
end
368-
369-
# Compute dogleg step
370-
dogleg!(cache)
371-
372-
# Compute the potentially new u
373-
@. cache.u_tmp = u + cache.du
374-
f(cache.fu_new, cache.u_tmp, p)
375-
trust_region_step!(cache)
376-
cache.stats.nf += 1
377-
cache.stats.nsolve += 1
378-
cache.stats.nfactors += 1
379-
return nothing
380-
end
381-
382-
function perform_step!(cache::TrustRegionCache{false})
383-
@unpack make_new_J, fu, f, u, p = cache
384-
385-
if make_new_J
386-
J = jacobian!!(cache.J, cache)
387-
__update_JᵀJ!(Val{false}(), cache, :H, J)
388-
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu))
389-
cache.stats.njacs += 1
390-
391-
if cache.linsolve === nothing
392-
# Scalar
393-
cache.u_gauss_newton = -cache.H \ cache.g
394-
else
395-
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
396-
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
397-
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J, b = _vec(fu),
398-
linu = _vec(cache.u_gauss_newton), p = p, reltol = cache.abstol)
399-
cache.linsolve = linres.cache
400-
@. cache.u_gauss_newton *= -1
401-
end
368+
cache.u_gauss_newton = _restructure(cache.u_gauss_newton, linres.u)
369+
@bb @. cache.u_gauss_newton *= -1
402370
end
403371

404-
# Compute the Newton step.
372+
# compute dogleg step
405373
dogleg!(cache)
406374

407-
# Compute the potentially new u
408-
cache.u_tmp = u + cache.du
409-
410-
cache.fu_new = f(cache.u_tmp, p)
375+
# compute the potentially new u
376+
@bb @. cache.u_cache_2 = cache.u + cache.du
377+
evaluate_f(cache, cache.u_tmp, cache.p, Val{:fu_cache_2}())
411378
trust_region_step!(cache)
412379
cache.stats.nf += 1
413380
cache.stats.nsolve += 1
414381
cache.stats.nfactors += 1
415382
return nothing
416383
end
417384

418-
function retrospective_step!(cache::TrustRegionCache)
419-
@unpack J, fu_prev, fu, u_prev, u = cache
420-
J = jacobian!!(deepcopy(J), cache)
421-
if J isa Number
422-
cache.H = J' * J
423-
cache.g = J' * fu
424-
else
425-
__update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J)
426-
__update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu)
427-
end
385+
function retrospective_step!(cache::TrustRegionCache{iip}) where {iip}
386+
J = jacobian!!(cache.J_cache, cache)
387+
__update_JᵀJ!(Val{iip}(), cache, :H, J)
388+
__update_Jᵀf!(Val{iip}(), cache, :g, :H, J, cache.fu)
428389
cache.stats.njacs += 1
429-
@unpack H, g, du = cache
430390

431-
return -(get_loss(fu_prev) - get_loss(fu)) /
432-
(dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2)
391+
# FIXME: Caching in __lr_mul
392+
num = __get_trust_region_loss(cache.fu) - __get_trust_region_loss(cache.fu_cache)
393+
denom = dot(_vec(du), _vec(g)) + __lr_mul(Val{iip}(), H, _vec(du)) / 2
394+
return num / denom
433395
end
434396

397+
# TODO
435398
function trust_region_step!(cache::TrustRegionCache)
436399
@unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme = cache
437400

438-
cache.loss_new = get_loss(fu_new)
401+
cache.loss_new = __get_trust_region_loss(fu_new)
439402

440403
# Compute the ratio of the actual reduction to the predicted reduction.
441404
cache.r = -(loss - cache.loss_new) /
@@ -556,6 +519,7 @@ function trust_region_step!(cache::TrustRegionCache)
556519
end
557520

558521
@unpack p1 = cache
522+
# TODO: Use the `vjp_autodiff` to for the jvp
559523
cache.trust_r = p1 * cache.internalnorm(jvp!(cache))
560524

561525
update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J,
@@ -608,6 +572,7 @@ function trust_region_step!(cache::TrustRegionCache)
608572
end
609573
end
610574

575+
# TODO
611576
function dogleg!(cache::TrustRegionCache{true})
612577
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache
613578

@@ -638,6 +603,7 @@ function dogleg!(cache::TrustRegionCache{true})
638603
@. cache.du = u_cauchy + τ * u_tmp
639604
end
640605

606+
# TODO
641607
function dogleg!(cache::TrustRegionCache{false})
642608
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache
643609

@@ -667,20 +633,14 @@ function dogleg!(cache::TrustRegionCache{false})
667633
cache.du = u_cauchy + τ * u_tmp
668634
end
669635

670-
function take_step!(cache::TrustRegionCache{true})
671-
cache.u_prev .= cache.u
672-
cache.u .= cache.u_tmp
673-
cache.fu_prev .= cache.fu
674-
cache.fu .= cache.fu_new
675-
end
676-
677-
function take_step!(cache::TrustRegionCache{false})
678-
cache.u_prev = cache.u
679-
cache.u = cache.u_tmp
680-
cache.fu_prev = cache.fu
681-
cache.fu = cache.fu_new
636+
function __take_step!(cache::TrustRegionCache)
637+
@bb copyto!(cache.u_cache, cache.u)
638+
@bb copyto!(cache.u, cache.u_cache_2) # u_tmp --> u_cache_2
639+
@bb copyto!(cache.fu_cache, cache.fu)
640+
@bb copyto!(cache.fu, cache.fu_cache_2) # fu_new --> fu_cache_2
682641
end
683642

643+
# TODO
684644
function jvp!(cache::TrustRegionCache{false})
685645
@unpack f, u, fu, uf = cache
686646
if isa(u, Number)
@@ -710,40 +670,15 @@ function not_terminated(cache::TrustRegionCache)
710670
end
711671
return true
712672
end
713-
get_fu(cache::TrustRegionCache) = cache.fu
714-
set_fu!(cache::TrustRegionCache, fu) = (cache.fu = fu)
715-
716-
function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
717-
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
718-
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
719-
cache.p = p
720-
if iip
721-
recursivecopy!(cache.u, u0)
722-
cache.f(cache.fu, cache.u, p)
723-
else
724-
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
725-
cache.u = u0
726-
cache.fu = cache.f(cache.u, p)
727-
end
728-
729-
reset!(cache.trace)
730-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
731-
termination_condition)
732673

733-
cache.abstol = abstol
734-
cache.reltol = reltol
735-
cache.tc_cache = tc_cache
736-
cache.maxiters = maxiters
737-
cache.stats.nf = 1
738-
cache.stats.nsteps = 1
739-
cache.force_stop = false
740-
cache.retcode = ReturnCode.Default
741-
cache.make_new_J = true
742-
cache.loss = get_loss(cache.fu)
674+
function __reinit_internal!(cache::TrustRegionCache; kwargs...)
675+
cache.loss = __get_trust_region_loss(cache.fu)
743676
cache.shrink_counter = 0
744-
cache.trust_r = convert(eltype(cache.u), cache.alg.initial_trust_radius)
745-
if iszero(cache.trust_r)
746-
cache.trust_r = convert(eltype(cache.u), cache.max_trust_r / 11)
747-
end
748-
return cache
677+
cache.trust_r = convert(eltype(cache.u),
678+
ifelse(cache.alg.initial_trust_radius == 0, cache.alg.initial_trust_radius,
679+
cache.max_trust_r / 11))
680+
cache.make_new_J = true
681+
return nothing
749682
end
683+
684+
__get_trust_region_loss(fu) = norm(fu)^2 / 2

src/utils.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ function wrapprecs(_Pl, _Pr, weight)
151151
return Pl, Pr
152152
end
153153

154-
get_loss(fu) = norm(fu)^2 / 2
155-
156154
function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method
157155
if (r c2)
158156
return (2 * (M - 1 - γ2) * atan(r - c2) + (1 + γ2)) / π
@@ -188,7 +186,7 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
188186
return fu
189187
end
190188

191-
function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
189+
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip <: Bool}
192190
if iip
193191
f(fu, u, p)
194192
return fu
@@ -197,11 +195,20 @@ function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
197195
end
198196
end
199197

200-
function evaluate_f(cache, u, p)
201-
if isinplace(cache)
202-
cache.prob.f(get_fu(cache), u, p)
198+
function evaluate_f(cache::AbstractNonlinearSolveCache, u, p,
199+
fu_sym::Val{FUSYM} = Val(nothing)) where {FUSYM}
200+
if FUSYM === nothing
201+
if isinplace(cache)
202+
cache.prob.f(get_fu(cache), u, p)
203+
else
204+
set_fu!(cache, cache.prob.f(u, p))
205+
end
203206
else
204-
set_fu!(cache, cache.prob.f(u, p))
207+
if isinplace(cache)
208+
cache.prob.f(__getproperty(cache, fu_sym), u, p)
209+
else
210+
setproperty!(cache, FUSYM, cache.prob.f(u, p))
211+
end
205212
end
206213
return nothing
207214
end

0 commit comments

Comments
 (0)