Skip to content

Commit 13e590e

Browse files
committed
LM Fixed
1 parent c8f7283 commit 13e590e

File tree

5 files changed

+32
-29
lines changed

5 files changed

+32
-29
lines changed

src/jacobian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
138138
kwargs...) where {needsJᵀJ, F}
139139
# NOTE: Scalar `u` assumes scalar output from `f`
140140
uf = SciMLBase.JacobianWrapper{false}(f, p)
141-
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
142-
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u
141+
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
143142
end
144143

145144
# Linear Solve Cache
146145
function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
147-
if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)
146+
if A isa Number ||
147+
(alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;))
148148
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
149149
# but there are quite a few unnecessary allocations
150150
return FakeLinearSolveJLCache(A, b)

src/levenberg.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ end
120120
fu
121121
fu_cache
122122
fu_cache_2
123-
du
124-
du_cache
125123
J
126124
JᵀJ
127125
Jv
@@ -197,9 +195,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
197195

198196
loss = internalnorm(fu)
199197

200-
@bb a = similar(du)
201-
@bb v_old = copy(v)
202-
@bb δ = similar(du)
198+
a = du # `du` is not used anywhere, use it to store `a`
203199

204200
make_new_J = true
205201

@@ -215,8 +211,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
215211
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
216212

217213
if !fastls
218-
@bb mat_tmp = similar(JᵀJ)
219-
@bb mat_tmp .*= T(0)
214+
@bb mat_tmp = zero(JᵀJ)
220215
rhs_tmp = nothing
221216
else
222217
mat_tmp = _vcat(J, DᵀD)
@@ -229,15 +224,14 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
229224
@bb u_cache = copy(u)
230225
@bb u_cache_2 = similar(u)
231226
@bb fu_cache_2 = similar(fu)
232-
@bb du_cache = similar(du)
233227
Jv = J * v
234-
@bb v_cache = similar(v)
228+
@bb v_cache = zero(v)
235229

236230
return LevenbergMarquardtCache{iip, fastls}(f, alg, u, u_cache, u_cache_2, fu, fu_cache,
237-
fu_cache_2, du, du_cache, J, JᵀJ, Jv, DᵀD, v, v_cache, a, mat_tmp, rhs_tmp, p, uf,
231+
fu_cache_2, J, JᵀJ, Jv, DᵀD, v, v_cache, a, mat_tmp, rhs_tmp, p, uf,
238232
linsolve, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
239233
reltol, prob, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h,
240-
α_geodesic, b_uphill, min_damping_D, internalnorm(v_cache), loss, make_new_J,
234+
α_geodesic, b_uphill, min_damping_D, loss, loss, make_new_J,
241235
NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2, trace)
242236
end
243237

@@ -271,11 +265,12 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
271265
end
272266
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
273267
b = cache.rhs_tmp, linu = _vec(cache.v), cache.p, reltol = cache.abstol)
268+
cache.linsolve = linres.cache
274269
@bb @. cache.v = -linres.u
275270
else
276271
@bb cache.u_cache_2 = transpose(cache.J) × cache.fu
277272
@bb @. cache.mat_tmp = cache.JᵀJ + cache.λ * cache.DᵀD
278-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
273+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
279274
b = _vec(cache.u_cache_2), linu = _vec(cache.v), cache.p, reltol = cache.abstol)
280275
cache.linsolve = linres.cache
281276
@bb @. cache.v = -linres.u
@@ -289,7 +284,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
289284
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
290285

291286
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
292-
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
287+
# NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
293288
@bb cache.Jv = cache.J × cache.v
294289
@bb @. cache.fu_cache_2 = (2 / cache.h) *
295290
((cache.fu_cache_2 - cache.fu) / cache.h - cache.Jv)
@@ -301,13 +296,14 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
301296
end
302297
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.a),
303298
cache.p, reltol = cache.abstol)
299+
cache.linsolve = linres.cache
304300
@bb @. cache.a = -linres.u
305301
else
306-
@bb cache.u_cache_2 = transpose(J) × cache.fu_cache_2
302+
@bb cache.u_cache_2 = transpose(cache.J) × cache.fu_cache_2
307303
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_cache_2),
308304
linu = _vec(cache.a), cache.p, reltol = cache.abstol)
309305
cache.linsolve = linres.cache
310-
@bb @. cache.a = -linres.du
306+
@bb @. cache.a = -linres.u
311307
end
312308

313309
cache.stats.nsolve += 2
@@ -316,8 +312,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
316312
# Require acceptable steps to satisfy the following condition.
317313
norm_v = cache.internalnorm(cache.v)
318314
if 2 * cache.internalnorm(cache.a) cache.α_geodesic * norm_v
319-
@bb @. cache.du_cache = cache.v + cache.a / 2
320-
@bb @. cache.u_cache_2 = cache.u + cache.du_cache
315+
@bb @. cache.u_cache_2 = cache.u + cache.v + cache.a / 2
321316
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
322317
loss = cache.internalnorm(cache.fu_cache_2)
323318

@@ -326,7 +321,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
326321
if (1 - β)^cache.b_uphill * loss cache.loss_old
327322
# Accept step.
328323
@bb copyto!(cache.u, cache.u_cache_2)
329-
check_and_update!(cache.tc_cache_1, cache, cache.fu_cache, cache.u,
324+
check_and_update!(cache.tc_cache_1, cache, cache.fu_cache_2, cache.u,
330325
cache.u_cache)
331326
if !cache.force_stop && cache.tc_cache_2 !== nothing # For NLLS Problems
332327
@bb @. cache.fu = cache.fu_cache_2 - cache.fu

src/pseudotransient.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@ function perform_step!(cache::PseudoTransientCache{iip}) where {iip}
112112
if cache.J isa SciMLOperators.AbstractSciMLOperator
113113
A = cache.J - inv_α * I
114114
elseif setindex_trait(cache.J) === CanSetindex()
115-
idxs = diagind(cache.J)
116115
if fast_scalar_indexing(cache.J)
117116
@inbounds for i in axes(cache.J, 1)
118117
cache.J[i, i] = cache.J[i, i] - inv_α
119118
end
120119
else
120+
idxs = diagind(cache.J)
121121
@.. broadcast=false @view(cache.J[idxs])=@view(cache.J[idxs]) - inv_α
122122
end
123123
A = cache.J

src/utils.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,20 @@ end
457457
return Diagonal(max.(y.diag, x))
458458
end
459459
end
460-
@inline @views function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
461-
x_diag = x[diagind(x)]
460+
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
462461
if setindex_trait(y.diag) === CanSetindex()
463-
@. y.diag = max(y.diag, x_diag)
464-
return y
462+
if fast_scalar_indexing(y.diag)
463+
@inbounds for i in axes(x, 1)
464+
y.diag[i] = max(y.diag[i], x[i, i])
465+
end
466+
return y
467+
else
468+
idxs = diagind(x)
469+
@.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
470+
return y
471+
end
465472
else
466-
return Diagonal(max.(y.diag, x_diag))
473+
idxs = diagind(x)
474+
return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
467475
end
468476
end

test/23_test_problems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ end
7373

7474
# dictionary with indices of test problems where method does not converge to small residual
7575
broken_tests = Dict(alg => Int[] for alg in alg_ops)
76-
broken_tests[alg_ops[1]] = [3, 6, 17, 21]
77-
broken_tests[alg_ops[2]] = [3, 6, 17, 21]
76+
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
77+
broken_tests[alg_ops[2]] = [3, 6, 11, 17, 21]
7878
broken_tests[alg_ops[3]] = [6, 11, 17, 21]
7979

8080
test_on_library(problems, dicts, alg_ops, broken_tests)

0 commit comments

Comments
 (0)