Skip to content

Commit c6f69d6

Browse files
authored
Use gradient!(d, x). in perform_linesearch (#1207)
* Use gradient!(d, x). in perform_linesearch This shouldn't matter, but *if* somehow happened to be calling the objective function in the callback or some other way this would now produce the correct result. Previously, it used the most recent evaluation. Now, it will try to evaluate at state.x and if that was the last point evaluated we'll just use the cached result. This should always be the case unless someone messed with the objective outside. We could even assert that g_calls is constant around the call to be sure to catch it but then users wouldn't be able to call the gradient outside at their own expense of runtime. * Update perform_linesearch.jl
1 parent 3e1393d commit c6f69d6

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/utilities/perform_linesearch.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,27 @@ _alphaguess(a::Number) = LineSearches.InitialStatic(alpha = a)
99
# project_tangent! here, because we already did that inplace on gradient(d) after
1010
# the last evaluation (we basically just always do it)
1111
function reset_search_direction!(state, d, method::BFGS)
12+
gx = gradient!(d, state.x)
1213
if method.initial_invH === nothing
1314
n = length(state.x)
1415
T = typeof(state.invH)
1516
if method.initial_stepnorm === nothing
1617
state.invH .= _init_identity_matrix(state.x)
1718
else
18-
initial_scale = method.initial_stepnorm * inv(norm(gradient(d), Inf))
19+
initial_scale = method.initial_stepnorm * inv(norm(gx, Inf))
1920
state.invH .= _init_identity_matrix(state.x, initial_scale)
2021
end
2122
else
2223
state.invH .= method.initial_invH(state.x)
2324
end
2425
# copyto!(state.invH, method.initial_invH(state.x))
25-
state.s .= .-gradient(d)
26+
state.s .= .-gx
2627
return true
2728
end
2829

2930
function reset_search_direction!(state, d, method::LBFGS)
3031
state.pseudo_iteration = 1
31-
state.s .= .-gradient(d)
32+
state.s .= .-gradient!(d, state.x)
3233
return true
3334
end
3435

@@ -39,12 +40,14 @@ end
3940

4041
function perform_linesearch!(state, method, d)
4142
# Calculate search direction dphi0
42-
dphi_0 = real(dot(gradient(d), state.s))
43+
fx = value_gradient!(d, state.x)
44+
gx = gradient(d)
45+
dphi_0 = real(dot(gx, state.s))
4346
# reset the direction if it becomes corrupted
4447
if dphi_0 >= zero(dphi_0) && reset_search_direction!(state, d, method)
45-
dphi_0 = real(dot(gradient(d), state.s)) # update after direction reset
48+
dphi_0 = real(dot(gx, state.s)) # update after direction reset
4649
end
47-
phi_0 = value(d)
50+
phi_0 = value!(d, state.x)
4851

4952
# Guess an alpha
5053
method.alphaguess!(method.linesearch!, state, phi_0, dphi_0, d)

0 commit comments

Comments
 (0)