Skip to content

Commit 28e39dd

Browse files
committed
Kind of finish LM
1 parent 51f4a3e commit 28e39dd

File tree

2 files changed

+100
-69
lines changed

2 files changed

+100
-69
lines changed

src/levenberg.jl

Lines changed: 82 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
179179
else
180180
uf, linsolve, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p,
181181
Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(false))
182-
@bb JᵀJ = similar(u)
182+
u_ = _vec(u)
183+
@bb JᵀJ = similar(u_)
183184
@bb v = similar(du)
184185
end
185186

@@ -241,91 +242,103 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
241242
end
242243

243244
function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip, fastls}
245+
@unpack alg, linsolve = cache
246+
244247
if cache.make_new_J
245248
cache.J = jacobian!!(cache.J, cache)
246249
if fastls
247250
cache.JᵀJ = __sum_JᵀJ!!(cache.JᵀJ, cache.J)
248-
# cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
249251
else
250252
@bb cache.JᵀJ = transpose(cache.J) × cache.J
251-
# cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
252253
end
254+
cache.DᵀD = __update_LM_diagonal!!(cache.DᵀD, cache.JᵀJ)
253255
cache.make_new_J = false
254256
end
255257

256-
# @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
257-
258258
# Usual Levenberg-Marquardt step ("velocity").
259259
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
260-
# if fastls
261-
# copyto!(@view(cache.mat_tmp[1:length(fu1), :]), cache.J)
262-
# cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
263-
# cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)
264-
# linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
265-
# b = cache.rhs_tmp, linu = _vec(cache.du), p = p, reltol = cache.abstol)
266-
# _vec(cache.v) .= -_vec(cache.du)
267-
# else
268-
# mul!(_vec(cache.u_tmp), J', _vec(fu1))
269-
# @. cache.mat_tmp = JᵀJ + λ * DᵀD
270-
# linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
271-
# b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
272-
# cache.linsolve = linres.cache
273-
# _vec(cache.v) .= -_vec(cache.du)
274-
# end
275-
276-
# update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
277-
# cache.v)
278-
279-
# # Geodesic acceleration (step_size = v + a / 2).
280-
# @unpack v, α_geodesic, h = cache
281-
# cache.u_tmp .= _restructure(cache.u_tmp, _vec(u) .+ h .* _vec(v))
282-
# f(cache.fu_tmp, cache.u_tmp, p)
283-
284-
# # The following lines do: cache.a = -J \ cache.fu_tmp
285-
# # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
286-
# mul!(_vec(cache.Jv), J, _vec(v))
287-
# @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
288-
# if fastls
289-
# cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
290-
# linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),
291-
# p = p, reltol = cache.abstol)
292-
# else
293-
# mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
294-
# linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
295-
# linu = _vec(cache.du), p = p, reltol = cache.abstol)
296-
# cache.linsolve = linres.cache
297-
# @. cache.a = -cache.du
298-
# end
260+
if fastls
261+
if setindex_trait(cache.mat_tmp) === CanSetindex()
262+
copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J)
263+
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD
264+
else
265+
cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD)
266+
end
267+
if setindex_trait(cache.rhs_tmp) === CanSetindex()
268+
cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu)
269+
else
270+
cache.rhs_tmp = _vcat(_vec(cache.fu), zero(_vec(cache.u)))
271+
end
272+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
273+
b = cache.rhs_tmp, linu = _vec(cache.v), cache.p, reltol = cache.abstol)
274+
@bb @. cache.v = -linres.u
275+
else
276+
@bb cache.u_cache_2 = transpose(J) × cache.fu
277+
@bb @. cache.mat_tmp = cache.JᵀJ + cache.λ * cache.DᵀD
278+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
279+
b = _vec(cache.u_cache_2), linu = _vec(cache.v), cache.p, reltol = cache.abstol)
280+
cache.linsolve = linres.cache
281+
@bb @. cache.v = -linres.u
282+
end
283+
284+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
285+
cache.v)
286+
287+
# Geodesic acceleration (step_size = v + a / 2).
288+
@bb @. cache.u_cache_2 = cache.u + cache.h * cache.v
289+
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
290+
291+
# The following lines do: cache.a = -J \ cache.fu_tmp
292+
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
293+
@bb cache.Jv = cache.J × cache.v
294+
@bb @. cache.fu_cache_2 = (2 / cache.h) *
295+
((cache.fu_cache_2 - cache.fu) / cache.h - cache.Jv)
296+
if fastls
297+
if setindex_trait(cache.rhs_tmp) === CanSetindex()
298+
cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu_cache_2)
299+
else
300+
cache.rhs_tmp = _vcat(_vec(cache.fu_cache_2), zero(_vec(cache.u)))
301+
end
302+
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.a),
303+
cache.p, reltol = cache.abstol)
304+
@bb @. cache.a = -linres.u
305+
else
306+
@bb cache.u_cache_2 = transpose(J) × cache.fu_cache_2
307+
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_cache_2),
308+
linu = _vec(cache.a), cache.p, reltol = cache.abstol)
309+
cache.linsolve = linres.cache
310+
@bb @. cache.a = -linres.du
311+
end
299312

300313
cache.stats.nsolve += 2
301314
cache.stats.nfactors += 2
302315

303316
# Require acceptable steps to satisfy the following condition.
304-
norm_v = cache.internalnorm(v)
305-
if 2 * cache.internalnorm(cache.a) ≤ α_geodesic * norm_v
306-
# _vec(cache.δ) .= _vec(v) .+ _vec(cache.a) ./ 2
307-
# @unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
308-
# f(cache.fu_tmp, u .+ δ, p)
309-
# loss = cache.internalnorm(cache.fu_tmp)
310-
311-
# # Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
312-
# β = dot(v, v_old) / (norm_v * norm_v_old)
313-
# if (1 - β)^b_uphill * loss ≤ loss_old
314-
# # Accept step.
315-
# cache.u .+= δ
316-
# check_and_update!(cache.tc_cache_1, cache, cache.fu_tmp, cache.u, cache.u_prev)
317-
# if !cache.force_stop && cache.tc_cache_2 !== nothing
318-
# # For NLLS Problems
319-
# cache.fu1 .= cache.fu_tmp .- cache.fu1
320-
# check_and_update!(cache.tc_cache_2, cache, cache.fu1, cache.u, cache.u_prev)
321-
# end
322-
# cache.fu1 .= cache.fu_tmp
323-
# _vec(cache.v_old) .= _vec(v)
324-
# cache.norm_v_old = norm_v
325-
# cache.loss_old = loss
326-
# cache.λ_factor = 1 / cache.damping_decrease_factor
327-
# cache.make_new_J = true
328-
# end
317+
norm_v = cache.internalnorm(cache.v)
318+
if 2 * cache.internalnorm(cache.a) ≤ cache.α_geodesic * norm_v
319+
@bb @. cache.du_cache = cache.v + cache.a / 2
320+
@bb @. cache.u_cache_2 = cache.u + cache.du_cache
321+
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
322+
loss = cache.internalnorm(cache.fu_cache_2)
323+
324+
# Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
325+
β = dot(cache.v, cache.v_cache) / (norm_v * cache.norm_v_old)
326+
if (1 - β)^cache.b_uphill * loss ≤ cache.loss_old
327+
# Accept step.
328+
@bb copyto!(cache.u, cache.u_cache_2)
329+
check_and_update!(cache.tc_cache_1, cache, cache.fu_cache, cache.u,
330+
cache.u_cache)
331+
if !cache.force_stop && cache.tc_cache_2 !== nothing # For NLLS Problems
332+
@bb @. cache.fu = cache.fu_cache_2 - cache.fu
333+
check_and_update!(cache.tc_cache_2, cache, cache.fu, cache.u, cache.u_cache)
334+
end
335+
@bb copyto!(cache.fu_cache, cache.fu_cache_2)
336+
@bb copyto!(cache.v_cache, cache.v)
337+
cache.norm_v_old = norm_v
338+
cache.loss_old = loss
339+
cache.λ_factor = 1 / cache.damping_decrease_factor
340+
cache.make_new_J = true
341+
end
329342
end
330343

331344
@bb copyto!(cache.u_cache, cache.u)

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,21 @@ function __sum_JᵀJ!!(y, J)
447447
return sum(abs2, J'; dims = 1)
448448
end
449449
end
450+
451+
"""
    __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)

Raise every entry of the diagonal of `y` to at least the matching entry of `x`,
i.e. compute `max.(y.diag, x)`.

When `y.diag` reports `CanSetindex()` the result is written into `y.diag` and `y`
itself is returned; otherwise a freshly built `Diagonal` is returned and `y` is left
untouched. Callers should therefore always use the return value.
"""
function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
    d = y.diag
    # Storage that cannot be written in place: build a new Diagonal instead.
    if setindex_trait(d) !== CanSetindex()
        return Diagonal(max.(d, x))
    end
    d .= max.(d, x)
    return y
end
459+
"""
    __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)

Raise every entry of the diagonal of `y` to at least the matching main-diagonal
entry of the matrix `x`, i.e. compute `max.(y.diag, x[diagind(x)])`.

The diagonal of `x` is read through a view, so it is not copied. When `y.diag`
reports `CanSetindex()` the result is written into `y.diag` and `y` itself is
returned; otherwise a freshly built `Diagonal` is returned and `y` is left untouched.
"""
function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
    # Non-allocating window onto the main diagonal of `x`.
    x_diag = view(x, diagind(x))
    d = y.diag
    # Storage that cannot be written in place: build a new Diagonal instead.
    setindex_trait(d) === CanSetindex() || return Diagonal(max.(d, x_diag))
    d .= max.(d, x_diag)
    return y
end

0 commit comments

Comments
 (0)