
Commit 3259cd2

Add LineSearchTestCase (#177)

* Add LineSearchTestCase. Also includes the failing case in PR #174.
* Add caching to all line search algorithms
* Add to docs
* Test caching for all algorithms

Co-authored-by: Mateusz Baran <[email protected]>
1 parent ded667a commit 3259cd2

File tree

15 files changed (+238 −25 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 docs/build
 docs/src/examples/generated
 /docs/Manifest.toml
+Manifest.toml

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,6 @@
 name = "LineSearches"
 uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
-version = "7.2.0"
+version = "7.3.0"

 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,6 +13,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 DoubleFloats = "1"
 NLSolversBase = "7"
 NaNMath = "1"
+Optim = "1"
+OptimTestProblems = "2"
 Parameters = "0.10, 0.11, 0.12"
 julia = "1.6"

docs/src/index.md

Lines changed: 8 additions & 0 deletions
@@ -46,6 +46,14 @@ using LineSearches
 ```
 to load the package.

+## Debugging
+
+If you suspect a method of suboptimal performance or find that your code errors,
+create a [`LineSearchCache`](@ref) to record intermediate values for later
+inspection and analysis. If you're using this via Optim.jl, configure it inside
+the method, e.g., `Newton(linesearch=LineSearches.MoreThuente(; cache))`. The
+values stored in the cache reflect the most recent line search performed during
+the optimization.

 ## References
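To make the workflow above concrete, here is a minimal sketch. It assumes the zero-argument `LineSearchCache{Float64}()` constructor from `types.jl` (not shown in this diff) and Optim.jl's stock `optimize` API; the objective `f` is a toy function chosen for illustration:

```julia
using Optim, LineSearches

cache = LineSearchCache{Float64}()
ls = LineSearches.MoreThuente(; cache)

f(x) = sum(abs2, x .- [1.0, 2.0])  # toy objective, for illustration only
res = optimize(f, zeros(2), Newton(linesearch = ls); autodiff = :forward)

# The cache reflects the most recent line search performed:
@show cache.alphas cache.values cache.slopes
```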

docs/src/reference/linesearch.md

Lines changed: 6 additions & 0 deletions
@@ -11,3 +11,9 @@ MoreThuente
 Static
 StrongWolfe
 ```
+
+## Debugging
+
+```@docs
+LineSearchCache
+```

src/LineSearches.jl

Lines changed: 23 additions & 4 deletions
@@ -1,5 +1,3 @@
-__precompile__()
-
 module LineSearches

 using Printf
@@ -9,13 +7,14 @@ using Parameters, NaNMath
 import NLSolversBase
 import NLSolversBase: AbstractObjective

-export LineSearchException
+export LineSearchException, LineSearchCache

-export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
+export AbstractLineSearch, BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe

 export InitialHagerZhang, InitialStatic, InitialPrevious,
        InitialQuadratic, InitialConstantChange

+
 function make_ϕ(df, x_new, x, s)
     function ϕ(α)
         # Move a distance of alpha in the direction of s
@@ -91,6 +90,26 @@ end

 include("types.jl")

+# The following don't extend `empty!` and `push!` because we want implementations for `nothing`,
+# and that would be piracy.
+emptycache!(cache::LineSearchCache) = begin
+    empty!(cache.alphas)
+    empty!(cache.values)
+    empty!(cache.slopes)
+end
+emptycache!(::Nothing) = nothing
+pushcache!(cache::LineSearchCache, α, val, slope) = begin
+    push!(cache.alphas, α)
+    push!(cache.values, val)
+    push!(cache.slopes, slope)
+end
+pushcache!(cache::LineSearchCache, α, val) = begin
+    push!(cache.alphas, α)
+    push!(cache.values, val)
+end
+pushcache!(::Nothing, α, val, slope) = nothing
+pushcache!(::Nothing, α, val) = nothing
+
 # Line Search Methods
 include("backtracking.jl")
 include("strongwolfe.jl")
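As a concrete illustration of the helpers above, here is a small usage sketch (again assuming the `LineSearchCache{Float64}()` constructor from `types.jl`). Because the `::Nothing` methods are no-ops, call sites need no `if cache !== nothing` guards for these two operations:

```julia
using LineSearches

cache = LineSearchCache{Float64}()
LineSearches.pushcache!(cache, 0.5, 0.25, -1.0)  # records α, ϕ(α), ϕ'(α)
LineSearches.emptycache!(cache)                  # clears all three vectors

# With caching disabled, the same calls silently do nothing:
LineSearches.pushcache!(nothing, 0.5, 0.25, -1.0)
LineSearches.emptycache!(nothing)
```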

src/backtracking.jl

Lines changed: 8 additions & 2 deletions
@@ -8,13 +8,14 @@ there exists a factor ρ = ρ(c₁) such that α' ≦ ρ α.

 This is a modification of the algorithm described in Nocedal Wright (2nd ed), Sec. 3.5.
 """
-@with_kw struct BackTracking{TF, TI}
+@with_kw struct BackTracking{TF, TI} <: AbstractLineSearch
     c_1::TF = 1e-4
     ρ_hi::TF = 0.5
     ρ_lo::TF = 0.1
     iterations::TI = 1_000
     order::TI = 3
     maxstep::TF = Inf
+    cache::Union{Nothing,LineSearchCache{TF}} = nothing
 end
 BackTracking{TF}(args...; kwargs...) where TF = BackTracking{TF,Int}(args...; kwargs...)

@@ -37,7 +38,9 @@ end

 # TODO: Should we deprecate the interface that only uses the ϕ argument?
 function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
-    @unpack c_1, ρ_hi, ρ_lo, iterations, order = ls
+    @unpack c_1, ρ_hi, ρ_lo, iterations, order, cache = ls
+    emptycache!(cache)
+    pushcache!(cache, 0, ϕ_0, dϕ_0) # backtracking doesn't use the slope except here

     iterfinitemax = -log2(eps(real(Tα)))

@@ -68,6 +71,8 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα

         ϕx_1 = ϕ(α_2)
     end
+    pushcache!(cache, αinitial, ϕx_1)
+    # TODO: check if value is finite (maybe iterfinite > iterfinitemax)

     # Backtrack until we satisfy sufficient decrease condition
     while ϕx_1 > ϕ_0 + c_1 * α_2 * dϕ_0
@@ -112,6 +117,7 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα

         # Evaluate f(x) at proposed position
         ϕx_0, ϕx_1 = ϕx_1, ϕ(α_2)
+        pushcache!(cache, α_2, ϕx_1)
     end

     return α_2, ϕx_1
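The functional interface in the hunk above can also be exercised directly, without an optimizer. A hedged sketch with a toy one-dimensional slice ϕ(α) = f(x + αs):

```julia
using LineSearches

cache = LineSearchCache{Float64}()
ls = BackTracking(; cache)

ϕ(α) = (α - 1)^2                  # toy slice; ϕ(0) = 1, ϕ'(0) = -2
α, ϕα = ls(ϕ, 1.0, ϕ(0.0), -2.0)  # αinitial = 1.0

@show α ϕα cache.alphas cache.values
```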

src/hagerzhang.jl

Lines changed: 14 additions & 7 deletions
@@ -80,7 +80,7 @@ Conjugate gradient line search implementation from:
   conjugate gradient method with guaranteed descent. ACM
   Transactions on Mathematical Software 32: 113–137.
 """
-@with_kw struct HagerZhang{T, Tm}
+@with_kw struct HagerZhang{T, Tm} <: AbstractLineSearch
     delta::T = DEFAULTDELTA # c_1 Wolfe sufficient decrease condition
     sigma::T = DEFAULTSIGMA # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
     alphamax::T = Inf
@@ -91,6 +91,7 @@ Conjugate gradient line search implementation from:
     psi3::T = 0.1
     display::Int = 0
     mayterminate::Tm = Ref{Bool}(false)
+    cache::Union{Nothing,LineSearchCache{T}} = nothing
 end
 HagerZhang{T}(args...; kwargs...) where T = HagerZhang{T, Base.RefValue{Bool}}(args...; kwargs...)

@@ -109,9 +110,11 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
                           phi_0::Real,
                           dphi_0::Real) where T # Should c and phi_0 be same type?
     @unpack delta, sigma, alphamax, rho, epsilon, gamma,
-        linesearchmax, psi3, display, mayterminate = ls
+        linesearchmax, psi3, display, mayterminate, cache = ls
+    emptycache!(cache)

     zeroT = convert(T, 0)
+    pushcache!(cache, zeroT, phi_0, dphi_0)
     if !(isfinite(phi_0) && isfinite(dphi_0))
         throw(LineSearchException("Value and slope at step length = 0 must be finite.", T(0)))
     end
@@ -124,9 +127,13 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
     # Prevent values of x_new = x+αs that are likely to make
     # ϕ(x_new) infinite
     iterfinitemax::Int = ceil(Int, -log2(eps(T)))
-    alphas = [zeroT] # for bisection
-    values = [phi_0]
-    slopes = [dphi_0]
+    if cache !== nothing
+        @unpack alphas, values, slopes = cache
+    else
+        alphas = [zeroT] # for bisection
+        values = [phi_0]
+        slopes = [dphi_0]
+    end
     if display & LINESEARCH > 0
         println("New linesearch")
     end
@@ -203,10 +210,10 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
     else
         # We're still going downhill, expand the interval and try again.
         # Reaching this branch means that dphi_c < 0 and phi_c <= phi_0 + ϵ_k
-        # So cold = c has a lower objective than phi_0 up to epsilon.
+        # So cold = c has a lower objective than phi_0 up to epsilon.
         # This makes it a viable step to return if bracketing fails.

-        # Bracketing can fail if no cold < c <= alphamax can be found with finite phi_c and dphi_c.
+        # Bracketing can fail if no cold < c <= alphamax can be found with finite phi_c and dphi_c.
         # Going back to the loop with c = cold will only result in infinite cycling.
         # So returning (cold, phi_cold) and exiting the line search is the best move.
         cold = c
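For `HagerZhang`, a similar hedged sketch using the `(ϕ, ϕdϕ, c, phi_0, dphi_0)` calling convention visible above; `ϕdϕ` returns a `(value, slope)` pair and the quadratic is illustrative:

```julia
using LineSearches

cache = LineSearchCache{Float64}()
ls = HagerZhang(; cache)

ϕ(α) = (α - 1)^2
ϕdϕ(α) = ((α - 1)^2, 2 * (α - 1))  # (value, slope) of the toy slice

α, ϕα = ls(ϕ, ϕdϕ, 0.5, ϕ(0.0), -2.0)

# With a cache attached, the bisection arrays alias the cache's vectors,
# so every trial point survives the call:
@show cache.alphas cache.values cache.slopes
```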

src/morethuente.jl

Lines changed: 8 additions & 2 deletions
@@ -138,13 +138,14 @@ The line search implementation from:
   Line search algorithms with guaranteed sufficient decrease.
   ACM Transactions on Mathematical Software (TOMS) 20.3 (1994): 286-307.
 """
-@with_kw struct MoreThuente{T}
+@with_kw struct MoreThuente{T} <: AbstractLineSearch
     f_tol::T = 1e-4 # c_1 Wolfe sufficient decrease condition
     gtol::T = 0.9 # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
     x_tol::T = 1e-8
     alphamin::T = 1e-16
     alphamax::T = 65536.0
     maxfev::Int = 100
+    cache::Union{Nothing,LineSearchCache{T}} = nothing
 end

 function (ls::MoreThuente)(df::AbstractObjective, x::AbstractArray{T},
@@ -161,13 +162,15 @@ function (ls::MoreThuente)(ϕdϕ,
                            alpha::T,
                            ϕ_0,
                            dϕ_0) where T
-    @unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev = ls
+    @unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev, cache = ls
+    emptycache!(cache)

     iterfinitemax = -log2(eps(T))
     info = 0
     info_cstep = 1 # Info from step

     zeroT = convert(T, 0)
+    pushcache!(cache, zeroT, ϕ_0, dϕ_0)

     #
     # Check the input parameters for errors.
@@ -236,7 +239,9 @@ function (ls::MoreThuente)(ϕdϕ,
         # Make stmax = (3/2)*alpha < 2alpha in the first iteration below
         stx = (convert(T, 7)/8)*alpha
     end
+    pushcache!(cache, alpha, f, dg)
     # END: Ensure that the initial step provides finite function values
+    # TODO: check if value is finite (maybe iterfinite > iterfinitemax)

     while true
         #
@@ -282,6 +287,7 @@ function (ls::MoreThuente)(ϕdϕ,
         # and compute the directional derivative.
         #
         f, dg = ϕdϕ(alpha)
+        pushcache!(cache, alpha, f, dg)
         nfev += 1 # This includes calls to f() and g!()

         if isapprox(dg, 0, atol=eps(T)) # Should add atol value to MoreThuente
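Unlike `StrongWolfe` below, the `MoreThuente` functor needs only the combined `ϕdϕ` callable. A brief sketch under the same toy-slice assumptions as in the earlier examples:

```julia
using LineSearches

cache = LineSearchCache{Float64}()
ls = MoreThuente(; cache)

ϕdϕ(α) = ((α - 1)^2, 2 * (α - 1))  # (value, slope) of the toy slice
α, ϕα = ls(ϕdϕ, 0.5, 1.0, -2.0)    # ϕ_0 = ϕ(0) = 1, dϕ_0 = -2

@show α cache.alphas cache.values
```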

src/static.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 `Static` is intended for methods with well-scaled updates; i.e. Newton, on well-behaved problems.
 """
-struct Static end
+struct Static <: AbstractLineSearch end

 function (ls::Static)(df::AbstractObjective, x, s, α, x_new = similar(x), ϕ_0 = nothing, dϕ_0 = nothing)
     ϕ = make_ϕ(df, x_new, x, s)

src/strongwolfe.jl

Lines changed: 18 additions & 4 deletions
@@ -14,10 +14,11 @@ use `MoreThuente`, `HagerZhang` or `BackTracking`.
 * `c_2 = 0.9` : second (strong) Wolfe condition
 * `ρ = 2.0` : bracket growth
 """
-@with_kw struct StrongWolfe{T}
+@with_kw struct StrongWolfe{T} <: AbstractLineSearch
     c_1::T = 1e-4
     c_2::T = 0.9
     ρ::T = 2.0
+    cache::Union{Nothing,LineSearchCache{T}} = nothing
 end

 """
@@ -49,9 +50,11 @@ Both `alpha` and `ϕ(alpha)` are returned.
 """
 function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
                            alpha0::T, ϕ_0, dϕ_0) where T<:Real
-    @unpack c_1, c_2, ρ = ls
+    @unpack c_1, c_2, ρ, cache = ls
+    emptycache!(cache)

     zeroT = convert(T, 0)
+    pushcache!(cache, zeroT, ϕ_0, dϕ_0)

     # Step-sizes
     a_0 = zeroT
@@ -71,17 +74,21 @@ function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,

     while a_i < a_max
         ϕ_a_i = ϕ(a_i)
+        pushcache!(cache, a_i, ϕ_a_i)

         # Test Wolfe conditions
         if (ϕ_a_i > ϕ_0 + c_1 * a_i * dϕ_0) ||
            (ϕ_a_i >= ϕ_a_iminus1 && i > 1)
             a_star = zoom(a_iminus1, a_i,
                           dϕ_0, ϕ_0,
-                          ϕ, dϕ, ϕdϕ)
+                          ϕ, dϕ, ϕdϕ, cache)
             return a_star, ϕ(a_star)
         end

         dϕ_a_i = dϕ(a_i)
+        if cache !== nothing
+            push!(cache.slopes, dϕ_a_i)
+        end

         # Check condition 2
         if abs(dϕ_a_i) <= -c_2 * dϕ_0
@@ -91,7 +98,7 @@ function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
         # Check condition 3
         if dϕ_a_i >= zeroT # FIXME untested!
             a_star = zoom(a_i, a_iminus1,
-                          dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ)
+                          dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ, cache)
             return a_star, ϕ(a_star)
         end

@@ -117,6 +124,7 @@ function zoom(a_lo::T,
               ϕ,
               dϕ,
               ϕdϕ,
+              cache,
               c_1::Real = convert(T, 1)/10^4,
               c_2::Real = convert(T, 9)/10) where T

@@ -133,8 +141,10 @@ function zoom(a_lo::T,
         iteration += 1

         ϕ_a_lo, ϕprime_a_lo = ϕdϕ(a_lo)
+        pushcache!(cache, a_lo, ϕ_a_lo, ϕprime_a_lo)

         ϕ_a_hi, ϕprime_a_hi = ϕdϕ(a_hi)
+        pushcache!(cache, a_hi, ϕ_a_hi, ϕprime_a_hi)

         # Interpolate a_j
         if a_lo < a_hi
@@ -150,6 +160,7 @@ function zoom(a_lo::T,

         # Evaluate ϕ(a_j)
         ϕ_a_j = ϕ(a_j)
+        pushcache!(cache, a_j, ϕ_a_j)

         # Check Armijo
         if (ϕ_a_j > ϕ_0 + c_1 * a_j * dϕ_0) ||
@@ -158,6 +169,9 @@ function zoom(a_lo::T,
         else
             # Evaluate ϕprime(a_j)
             ϕprime_a_j = dϕ(a_j)
+            if cache !== nothing
+                push!(cache.slopes, ϕprime_a_j)
+            end

             if abs(ϕprime_a_j) <= -c_2 * dϕ_0
                 return a_j
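Note that `StrongWolfe` and `zoom` record slopes only at points where the derivative is actually evaluated, so `cache.slopes` may end up shorter than `cache.alphas` after a call. A hedged sketch, with the same toy slice as in the earlier examples:

```julia
using LineSearches

cache = LineSearchCache{Float64}()
ls = StrongWolfe(; cache)

ϕ(α) = (α - 1)^2
dϕ(α) = 2 * (α - 1)
ϕdϕ(α) = (ϕ(α), dϕ(α))

α, ϕα = ls(ϕ, dϕ, ϕdϕ, 0.25, ϕ(0.0), dϕ(0.0))
@show length(cache.alphas) length(cache.slopes)
```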
