Skip to content

Commit f67ced5

Browse files
committed
avoid recomputation of GN step if TR step was rejected. Faster and avoids a bug due to mutation of Jacobian in dolinsolve.
1 parent b2b5d89 commit f67ced5

File tree

1 file changed

+25
-22
lines changed

1 file changed

+25
-22
lines changed

src/trustRegion.jl

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ end
203203
shrink_counter::Int
204204
du
205205
u_tmp
206+
u_gauss_newton
206207
u_cauchy
207208
fu_new
208209
make_new_J::Bool
@@ -229,6 +230,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
229230
linsolve_kwargs)
230231
u_tmp = zero(u)
231232
u_cauchy = zero(u)
233+
u_gauss_newton = zero(u)
232234

233235
loss_new = loss
234236
H = zero(J)
@@ -265,7 +267,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
265267
expand_factor = convert(trustType, alg.expand_factor)
266268

267269
# Parameters for the Schemes
268-
floatType = typeof(r)
269270
p1 = convert(floatType, 0.0)
270271
p2 = convert(floatType, 0.0)
271272
p3 = convert(floatType, 0.0)
@@ -321,28 +322,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
321322
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
322323
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
323324
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
324-
H, g, shrink_counter, du, u_tmp, u_cauchy, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
325+
H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
325326
NLStats(1, 0, 0, 0, 0))
326327
end
327328

328329
isinplace(::TrustRegionCache{iip}) where {iip} = iip
329330

330331
function perform_step!(cache::TrustRegionCache{true})
331-
@unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
332+
@unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache
332333
if cache.make_new_J
333334
jacobian!!(J, cache)
334335
mul!(cache.H, J', J)
335336
mul!(cache.g, J', fu)
336337
cache.stats.njacs += 1
337-
end
338338

339-
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
340-
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
341-
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
342-
linu = _vec(u_tmp),
339+
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
340+
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
341+
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
342+
linu = _vec(u_gauss_newton),
343343
p = p, reltol = cache.abstol)
344-
cache.linsolve = linres.cache
345-
cache.u_tmp .= -1 .* u_tmp
344+
cache.linsolve = linres.cache
345+
@. cache.u_gauss_newton = -1 * u_gauss_newton
346+
end
347+
348+
# Compute dogleg step
346349
dogleg!(cache)
347350

348351
# Compute the potentially new u
@@ -363,11 +366,10 @@ function perform_step!(cache::TrustRegionCache{false})
363366
cache.H = J' * J
364367
cache.g = J' * fu
365368
cache.stats.njacs += 1
369+
cache.u_gauss_newton = -1 .* (cache.H \ cache.g)
366370
end
367371

368-
@unpack g, H = cache
369372
# Compute the Newton step.
370-
cache.u_tmp = -1 .* (H \ g)
371373
dogleg!(cache)
372374

373375
# Compute the potentially new u
@@ -566,25 +568,26 @@ function trust_region_step!(cache::TrustRegionCache)
566568
end
567569

568570
function dogleg!(cache::TrustRegionCache{true})
569-
@unpack u_tmp, u_cauchy, trust_r = cache
571+
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache
570572

571573
# Take the full Gauss-Newton step if lies within the trust region.
572-
if norm(u_tmp) trust_r
573-
cache.du .= u_tmp
574+
if norm(u_gauss_newton) trust_r
575+
cache.du .= u_gauss_newton
574576
return
575577
end
576578

577579
# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
578580
l_grad = norm(cache.g) # length of the gradient
579581
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
580-
if d_cauchy > trust_r
582+
if d_cauchy >= trust_r
581583
@. cache.du = - (trust_r/l_grad) * cache.g # step to the end of the trust region
582584
return
583585
end
584-
586+
585587
# Take the intersection of dogled with trust region if Cauchy point lies inside the trust region
586588
@. u_cauchy = - (d_cauchy/l_grad) * cache.g # compute Cauchy point
587-
@. u_tmp -= u_cauchy # calf of the dogleg -- use u_tmp to avoid allocation
589+
@. u_tmp = u_gauss_newton - u_cauchy # calf of the dogleg -- use u_tmp to avoid allocation
590+
588591
a = dot(u_tmp, u_tmp)
589592
b = 2*dot(u_cauchy, u_tmp)
590593
c = d_cauchy^2 - trust_r^2
@@ -596,11 +599,11 @@ end
596599

597600

598601
function dogleg!(cache::TrustRegionCache{false})
599-
@unpack u_tmp, u_cauchy, trust_r = cache
602+
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache
600603

601604
# Take the full Gauss-Newton step if lies within the trust region.
602-
if norm(u_tmp) trust_r
603-
cache.du = deepcopy(u_tmp)
605+
if norm(u_gauss_newton) trust_r
606+
cache.du = deepcopy(u_gauss_newton)
604607
return
605608
end
606609

@@ -614,7 +617,7 @@ function dogleg!(cache::TrustRegionCache{false})
614617

615618
# Take the intersection of dogled with trust region if Cauchy point lies inside the trust region
616619
u_cauchy = - (d_cauchy/l_grad) * cache.g # compute Cauchy point
617-
u_tmp -= u_cauchy # calf of the dogleg -- use u_tmp to avoid allocation
620+
u_tmp = u_gauss_newton - u_cauchy # calf of the dogleg
618621
a = dot(u_tmp, u_tmp)
619622
b = 2*dot(u_cauchy, u_tmp)
620623
c = d_cauchy^2 - trust_r^2

0 commit comments

Comments
 (0)