Skip to content

Commit 5af2da0

Browse files
Fix levenburg
1 parent 1d0c424 commit 5af2da0

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/levenberg.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,21 +209,22 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
209209

210210
# Usual Levenberg-Marquardt step ("velocity").
211211
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
212-
mul!(cache.u_tmp, J', fu1)
212+
mul!(_vec(cache.u_tmp), J', _vec(fu1))
213213
@. cache.mat_tmp = JᵀJ + λ * DᵀD
214214
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
215215
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
216216
cache.linsolve = linres.cache
217-
@. cache.v = -cache.du
217+
_vec(cache.v) .= .-_vec(cache.du)
218218

219219
# Geodesic acceleration (step_size = v + a / 2).
220220
@unpack v, α_geodesic, h = cache
221-
f(cache.fu_tmp, u .+ h .* v, p)
221+
_vec(cache.du) .= _vec(u) .+ h .* _vec(v)
222+
f(cache.fu_tmp, cache.du, p)
222223

223224
# The following lines do: cache.a = -J \ cache.fu_tmp
224-
mul!(cache.Jv, J, v)
225+
mul!(_vec(cache.Jv), J, _vec(v))
225226
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
226-
mul!(cache.u_tmp, J', cache.fu_tmp)
227+
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
227228
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
228229
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
229230
linu = _vec(cache.du), p = p, reltol = cache.abstol)
@@ -235,7 +236,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
235236
# Require acceptable steps to satisfy the following condition.
236237
norm_v = norm(v)
237238
if 2 * norm(cache.a) α_geodesic * norm_v
238-
@. cache.δ = v + cache.a / 2
239+
_vec(cache.δ) .= _vec(v) .+ _vec(cache.a) ./ 2
239240
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
240241
f(cache.fu_tmp, u .+ δ, p)
241242
cache.stats.nf += 1
@@ -251,7 +252,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
251252
return nothing
252253
end
253254
cache.fu1 .= cache.fu_tmp
254-
cache.v_old .= v
255+
_vec(cache.v_old) .= _vec(v)
255256
cache.norm_v_old = norm_v
256257
cache.loss_old = loss
257258
cache.λ_factor = 1 / cache.damping_decrease_factor

0 commit comments

Comments
 (0)