Skip to content

Commit c8f7283

Browse files
committed
Patch tracing and LM
1 parent 28e39dd commit c8f7283

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

src/levenberg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
273273
b = cache.rhs_tmp, linu = _vec(cache.v), cache.p, reltol = cache.abstol)
274274
@bb @. cache.v = -linres.u
275275
else
276-
@bb cache.u_cache_2 = transpose(J) × cache.fu
276+
@bb cache.u_cache_2 = transpose(cache.J) × cache.fu
277277
@bb @. cache.mat_tmp = cache.JᵀJ + cache.λ * cache.DᵀD
278278
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
279279
b = _vec(cache.u_cache_2), linu = _vec(cache.v), cache.p, reltol = cache.abstol)
@@ -288,7 +288,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
288288
@bb @. cache.u_cache_2 = cache.u + cache.h * cache.v
289289
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
290290

291-
# The following lines do: cache.a = -J \ cache.fu_tmp
291+
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
292292
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
293293
@bb cache.Jv = cache.J × cache.v
294294
@bb @. cache.fu_cache_2 = (2 / cache.h) *
@@ -332,7 +332,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
332332
@bb @. cache.fu = cache.fu_cache_2 - cache.fu
333333
check_and_update!(cache.tc_cache_2, cache, cache.fu, cache.u, cache.u_cache)
334334
end
335-
@bb copyto!(cache.fu_cache, cache.fu_cache_2)
335+
@bb copyto!(cache.fu, cache.fu_cache_2)
336336
@bb copyto!(cache.v_cache, cache.v)
337337
cache.norm_v_old = norm_v
338338
cache.loss_old = loss

src/trace.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ function update_trace!(trace::NonlinearSolveTrace{ShT, StT}, iter, u, fu, J, δu
209209
return trace
210210
end
211211

212-
show_now = ShT && (iter % trace.trace_level.print_frequency == 1)
213-
store_now = StT && (iter % trace.trace_level.store_frequency == 1)
212+
show_now = ShT && (mod1(iter, trace.trace_level.print_frequency) == 1)
213+
store_now = StT && (mod1(iter, trace.trace_level.store_frequency) == 1)
214214
(show_now || store_now) && (entry = __trace_entry(trace.trace_level, iter, u, fu, J,
215215
δu, α))
216216
store_now && push!(trace.history, entry)
@@ -230,8 +230,8 @@ function update_trace_with_invJ!(trace::NonlinearSolveTrace{ShT, StT}, iter, u,
230230
return trace
231231
end
232232

233-
show_now = ShT && (iter % trace.trace_level.print_frequency == 1)
234-
store_now = StT && (iter % trace.trace_level.store_frequency == 1)
233+
show_now = ShT && (mod1(iter, trace.trace_level.print_frequency) == 1)
234+
store_now = StT && (mod1(iter, trace.trace_level.store_frequency) == 1)
235235
if show_now || store_now
236236
J_ = trace.trace_level isa TraceMinimal ? J : inv(J)
237237
entry = __trace_entry(trace.trace_level, iter, u, fu, J_, δu, α)

src/utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,15 +448,16 @@ function __sum_JᵀJ!!(y, J)
448448
end
449449
end
450450

451-
function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
451+
@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x)
452+
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
452453
if setindex_trait(y.diag) === CanSetindex()
453454
@. y.diag = max(y.diag, x)
454455
return y
455456
else
456457
return Diagonal(max.(y.diag, x))
457458
end
458459
end
459-
@views function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
460+
@inline @views function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
460461
x_diag = x[diagind(x)]
461462
if setindex_trait(y.diag) === CanSetindex()
462463
@. y.diag = max(y.diag, x_diag)

0 commit comments

Comments
 (0)