Skip to content

Commit c6992a5

Browse files
committed
Broyden with LineSearch
1 parent 6b83054 commit c6992a5

File tree

8 files changed

+352
-37
lines changed

8 files changed

+352
-37
lines changed

src/NonlinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import UnPack: @unpack
2626
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
2727
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
2828

29+
abstract type AbstractNonlinearSolveLineSearchAlgorithm end
30+
2931
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
3032
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
3133

@@ -105,6 +107,6 @@ export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, Pseu
105107
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
106108
export RobustMultiNewton, FastShortcutNonlinearPolyalg
107109

108-
export LineSearch
110+
export LineSearch, LiFukushimaLineSearch
109111

110112
end # module

src/broyden.jl

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
# Sadly `Broyden` is taken up by SimpleNonlinearSolve.jl
22
"""
3-
GeneralBroyden(max_resets)
4-
GeneralBroyden(; max_resets = 3)
3+
GeneralBroyden(max_resets, linesearch)
4+
GeneralBroyden(; max_resets = 3, linesearch = LineSearch())
55
6-
An implementation of `Broyden` with support for caching!
6+
An implementation of `Broyden` with resetting and line search.
77
88
## Arguments
99
1010
- `max_resets`: the maximum number of resets to perform. Defaults to `3`.
11+
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
12+
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
13+
used here directly, and they will be converted to the correct `LineSearch`. It is
14+
recommended to use [LiFukushimaLineSearch](@ref) -- a derivative free linesearch
15+
specifically designed for Broyden's method.
1116
"""
12-
struct GeneralBroyden <: AbstractNewtonAlgorithm{false, Nothing}
17+
@concrete struct GeneralBroyden <: AbstractNewtonAlgorithm{false, Nothing}
1318
max_resets::Int
19+
linesearch
1420
end
1521

16-
GeneralBroyden(; max_resets = 3) = GeneralBroyden(max_resets)
22+
function GeneralBroyden(; max_resets = 3, linesearch = LineSearch())
23+
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
24+
return GeneralBroyden(max_resets, linesearch)
25+
end
1726

1827
@concrete mutable struct GeneralBroydenCache{iip} <: AbstractNonlinearSolveCache{iip}
1928
f
@@ -29,13 +38,14 @@ GeneralBroyden(; max_resets = 3) = GeneralBroyden(max_resets)
2938
J⁻¹df
3039
force_stop::Bool
3140
resets::Int
32-
max_rests::Int
41+
max_resets::Int
3342
maxiters::Int
3443
internalnorm
3544
retcode::ReturnCode.T
3645
abstol
3746
prob
3847
stats::NLStats
48+
lscache
3949
end
4050

4151
get_fu(cache::GeneralBroydenCache) = cache.fu
@@ -46,19 +56,20 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
4656
@unpack f, u0, p = prob
4757
u = alias_u0 ? u0 : deepcopy(u0)
4858
fu = evaluate_f(prob, u)
49-
J⁻¹ = convert(parameterless_type(_mutable(u)),
50-
Matrix{eltype(u)}(I, length(fu), length(u)))
51-
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, similar(fu),
52-
similar(fu), p, J⁻¹, similar(fu'), _mutable_zero(u), false, 0, alg.max_resets,
53-
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
59+
J⁻¹ = __init_identity_jacobian(u, fu)
60+
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),
61+
zero(fu), p, J⁻¹, zero(fu'), _mutable_zero(u), false, 0, alg.max_resets,
62+
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
63+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
5464
end
5565

5666
function perform_step!(cache::GeneralBroydenCache{true})
5767
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
5868
T = eltype(u)
5969

6070
mul!(du, J⁻¹, -fu)
61-
u .+= du
71+
α = perform_linesearch!(cache.lscache, u, du)
72+
axpy!(α, du, u)
6273
f(fu2, u, p)
6374

6475
cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
@@ -68,7 +79,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
6879

6980
# Update the inverse jacobian
7081
dfu .= fu2 .- fu
71-
if cache.resets < cache.max_rests &&
82+
if cache.resets < cache.max_resets &&
7283
(all(x -> abs(x) ≤ 1e-12, du) || all(x -> abs(x) ≤ 1e-12, dfu))
7384
fill!(J⁻¹, 0)
7485
J⁻¹[diagind(J⁻¹)] .= T(1)
@@ -83,3 +94,57 @@ function perform_step!(cache::GeneralBroydenCache{true})
8394

8495
return nothing
8596
end
97+
98+
function perform_step!(cache::GeneralBroydenCache{false})
99+
@unpack f, p = cache
100+
T = eltype(cache.u)
101+
102+
cache.du = cache.J⁻¹ * -cache.fu
103+
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
104+
cache.u = cache.u .+ α * cache.du
105+
cache.fu2 = f(cache.u, p)
106+
107+
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
108+
cache.stats.nf += 1
109+
110+
cache.force_stop && return nothing
111+
112+
# Update the inverse jacobian
113+
cache.dfu = cache.fu2 .- cache.fu
114+
if cache.resets < cache.max_resets &&
115+
(all(x -> abs(x) ≤ 1e-12, cache.du) || all(x -> abs(x) ≤ 1e-12, cache.dfu))
116+
J⁻¹ = similar(cache.J⁻¹)
117+
fill!(J⁻¹, 0)
118+
J⁻¹[diagind(J⁻¹)] .= T(1)
119+
cache.J⁻¹ = J⁻¹
120+
cache.resets += 1
121+
else
122+
cache.J⁻¹df = cache.J⁻¹ * cache.dfu
123+
cache.J⁻¹₂ = cache.du' * cache.J⁻¹
124+
cache.du = (cache.du .- cache.J⁻¹df) ./ (dot(cache.du, cache.J⁻¹df) .+ T(1e-5))
125+
cache.J⁻¹ = cache.J⁻¹ .+ cache.du * cache.J⁻¹₂
126+
end
127+
cache.fu = cache.fu2
128+
129+
return nothing
130+
end
131+
132+
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
133+
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
134+
cache.p = p
135+
if iip
136+
recursivecopy!(cache.u, u0)
137+
cache.f(cache.fu, cache.u, p)
138+
else
139+
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
140+
cache.u = u0
141+
cache.fu = cache.f(cache.u, p)
142+
end
143+
cache.abstol = abstol
144+
cache.maxiters = maxiters
145+
cache.stats.nf = 1
146+
cache.stats.nsteps = 1
147+
cache.force_stop = false
148+
cache.retcode = ReturnCode.Default
149+
return cache
150+
end

src/klement.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/linesearch.jl

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@ function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = tr
2626
return LineSearch(method, autodiff, alpha)
2727
end
2828

29-
@concrete mutable struct LineSearchCache
29+
@inline function init_linesearch_cache(ls::LineSearch, args...)
30+
return init_linesearch_cache(ls.method, ls, args...)
31+
end
32+
33+
# LineSearches.jl doesn't have a supertype so default to that
34+
init_linesearch_cache(_, ls, f, u, p, fu, iip) = LineSearchesJLCache(ls, f, u, p, fu, iip)
35+
36+
# Wrapper over LineSearches.jl algorithms
37+
@concrete mutable struct LineSearchesJLCache
3038
f
3139
ϕ
3240
@@ -35,11 +43,11 @@ end
3543
ls
3644
end
3745

38-
function LineSearchCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
46+
function LineSearchesJLCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
3947
eval_f(u, du, α) = eval_f(u - α * du)
4048
eval_f(u) = f(u, p)
4149

42-
ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
50+
ls.method isa Static && return LineSearchesJLCache(eval_f, nothing, nothing, nothing,
4351
convert(typeof(u), ls.α), ls)
4452

4553
g(u, fu) = last(value_derivative(Base.Fix2(f, p), u)) * fu
@@ -73,11 +81,11 @@ function LineSearchCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
7381
return ϕdϕ_internal
7482
end
7583

76-
return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
84+
return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
7785
end
7886

79-
function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip}
80-
fu = iip ? fu1 : nothing
87+
function LineSearchesJLCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip}
88+
fu = iip ? deepcopy(fu1) : nothing
8189
u_ = _mutable_zero(u)
8290

8391
function eval_f(u, du, α)
@@ -86,7 +94,7 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip
8694
end
8795
eval_f(u) = evaluate_f(f, u, p, IIP; fu)
8896

89-
ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
97+
ls.method isa Static && return LineSearchesJLCache(eval_f, nothing, nothing, nothing,
9098
convert(eltype(u), ls.α), ls)
9199

92100
g₀ = _mutable_zero(u)
@@ -138,10 +146,10 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip
138146
return ϕdϕ_internal
139147
end
140148

141-
return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
149+
return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
142150
end
143151

144-
function perform_linesearch!(cache::LineSearchCache, u, du)
152+
function perform_linesearch!(cache::LineSearchesJLCache, u, du)
145153
cache.ls.method isa Static && return cache.α
146154

147155
ϕ = cache.ϕ(u, du)
@@ -155,3 +163,120 @@ function perform_linesearch!(cache::LineSearchCache, u, du)
155163

156164
return first(cache.ls.method(ϕ, cache.dϕ(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀))
157165
end
166+
167+
"""
168+
LiFukushimaLineSearch(; lambda_0 = 1.0, beta = 0.1, sigma_1 = 0.001,
169+
sigma_2 = 0.001, eta = 0.1, rho = 0.9, nan_max_iter = 5, maxiters = 50)
170+
171+
A derivative-free line search and global convergence of Broyden-like method for nonlinear
172+
equations by Dong-Hui Li & Masao Fukushima. For more details see
173+
https://doi.org/10.1080/10556780008805782
174+
"""
175+
struct LiFukushimaLineSearch{T} <: AbstractNonlinearSolveLineSearchAlgorithm
176+
λ₀::T
177+
β::T
178+
σ₁::T
179+
σ₂::T
180+
η::T
181+
ρ::T
182+
nan_max_iter::Int
183+
maxiters::Int
184+
end
185+
186+
function LiFukushimaLineSearch(; lambda_0 = 1.0, beta = 0.1, sigma_1 = 0.001,
187+
sigma_2 = 0.001, eta = 0.1, rho = 0.9, nan_max_iter = 5, maxiters = 50)
188+
T = promote_type(typeof(lambda_0), typeof(beta), typeof(sigma_1), typeof(eta),
189+
typeof(rho), typeof(sigma_2))
190+
return LiFukushimaLineSearch{T}(lambda_0, beta, sigma_1, sigma_2, eta, rho,
191+
nan_max_iter, maxiters)
192+
end
193+
194+
@concrete mutable struct LiFukushimaLineSearchCache{iip}
195+
f
196+
p
197+
u_cache
198+
fu_cache
199+
alg
200+
α
201+
end
202+
203+
function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f, _u, p, _fu,
204+
::Val{iip}) where {iip}
205+
fu = iip ? deepcopy(_fu) : nothing
206+
u = iip ? deepcopy(_u) : nothing
207+
return LiFukushimaLineSearchCache{iip}(f, p, u, fu, alg, ls.α)
208+
end
209+
210+
function perform_linesearch!(cache::LiFukushimaLineSearchCache{iip}, u, du) where {iip}
211+
(; β, σ₁, σ₂, η, λ₀, ρ, nan_max_iter, maxiters) = cache.alg
212+
λ₂ = λ₀
213+
λ₁ = λ₂
214+
215+
if iip
216+
cache.f(cache.fu_cache, u, cache.p)
217+
fx_norm = norm(cache.fu_cache, 2)
218+
else
219+
fx_norm = norm(cache.f(u, cache.p), 2)
220+
end
221+
222+
# Non-Blocking exit if the norm is NaN or Inf
223+
!isfinite(fx_norm) && return cache.α
224+
225+
# Early Terminate based on Eq. 2.7
226+
if iip
227+
cache.u_cache .= u .+ du
228+
cache.f(cache.fu_cache, cache.u_cache, cache.p)
229+
fxλ_norm = norm(cache.fu_cache, 2)
230+
else
231+
fxλ_norm = norm(cache.f(u .+ du, cache.p), 2)
232+
end
233+
234+
fxλ_norm ≤ ρ * fx_norm - σ₂ * norm(du, 2)^2 && return cache.α
235+
236+
if iip
237+
cache.u_cache .= u .+ λ₂ .* du
238+
cache.f(cache.fu_cache, cache.u_cache, cache.p)
239+
fxλp_norm = norm(cache.fu_cache, 2)
240+
else
241+
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
242+
end
243+
244+
if !isfinite(fxλp_norm)
245+
# Backtrack a finite number of steps
246+
nan_converged = false
247+
for _ in 1:nan_max_iter
248+
λ₁, λ₂ = λ₂, β * λ₂
249+
250+
if iip
251+
cache.u_cache .= u .+ λ₂ .* du
252+
cache.f(cache.fu_cache, cache.u_cache, cache.p)
253+
fxλp_norm = norm(cache.fu_cache, 2)
254+
else
255+
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
256+
end
257+
258+
nan_converged = isfinite(fxλp_norm)
259+
nan_converged && break
260+
end
261+
262+
# Non-Blocking exit if the norm is still NaN or Inf
263+
!nan_converged && return cache.α
264+
end
265+
266+
for _ in 1:maxiters
267+
if iip
268+
cache.u_cache .= u .+ λ₂ .* du
269+
cache.f(cache.fu_cache, cache.u_cache, cache.p)
270+
fxλp_norm = norm(cache.fu_cache, 2)
271+
else
272+
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p), 2)
273+
end
274+
275+
converged = fxλp_norm ≤ (1 + η) * fx_norm - σ₁ * λ₂^2 * norm(du, 2)^2
276+
277+
converged && break
278+
λ₁, λ₂ = λ₂, β * λ₂
279+
end
280+
281+
return λ₂
282+
end

src/raphson.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8282

8383
return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
8484
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
85-
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)))
85+
NLStats(1, 0, 0, 0, 0),
86+
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
8687
end
8788

8889
function perform_step!(cache::NewtonRaphsonCache{true})
@@ -96,7 +97,7 @@ function perform_step!(cache::NewtonRaphsonCache{true})
9697

9798
# Line Search
9899
α = perform_linesearch!(cache.lscache, u, du)
99-
@. u = u - α * du
100+
axpy!(α, du, u)
100101
f(cache.fu1, u, p)
101102

102103
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)

src/utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing,
3535
if chunk_size !== missing || standardtag !== missing || diff_type !== missing ||
3636
autodiff !== missing
3737
Base.depwarn("`chunk_size`, `standardtag`, `diff_type`, \
38-
`autodiff::Union{Val, Bool}` kwargs have been deprecated and will be removed in\
39-
v3. Update your code to directly specify autodiff=<ADTypes>",
38+
`autodiff::Union{Val, Bool}` kwargs have been deprecated and will be removed \
39+
in v3. Update your code to directly specify autodiff=<ADTypes>",
4040
:default_adargs_to_adtype)
4141
end
4242
chunk_size === missing && (chunk_size = Val{0}())
@@ -211,3 +211,13 @@ function __get_concrete_algorithm(alg, prob)
211211
end
212212
return set_ad(alg, ad)
213213
end
214+
215+
__init_identity_jacobian(u::Number, _) = u
216+
function __init_identity_jacobian(u, fu)
217+
return convert(parameterless_type(_mutable(u)),
218+
Matrix{eltype(u)}(I, length(fu), length(u)))
219+
end
220+
function __init_identity_jacobian(u::StaticArray, fu)
221+
return convert(MArray{Tuple{length(fu), length(u)}},
222+
Matrix{eltype(u)}(I, length(fu), length(u)))
223+
end

0 commit comments

Comments
 (0)