Skip to content

Commit c737bd2

Browse files
Merge pull request #125 from CCsimon123/master
Bug fix for TrustRegion when iip=true.
2 parents 9049a3c + 69350aa commit c737bd2

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

src/trustRegion.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
242242
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
243243
1, false, maxiters, internalnorm,
244244
ReturnCode.Default, abstol, prob, initial_trust_radius,
245-
max_trust_radius, loss, loss, H, fu, 0, u, u_tmp, fu, true,
245+
max_trust_radius, loss, loss, H, zero(fu), 0, zero(u),
246+
u_tmp, zero(fu), true,
246247
loss)
247248
end
248249

@@ -307,10 +308,7 @@ function trust_region_step!(cache::TrustRegionCache)
307308
cache.shrink_counter = 0
308309
end
309310
if r > alg.step_threshold
310-
311-
# Take the step.
312-
cache.u = u_tmp
313-
cache.fu = fu_new
311+
take_step!(cache)
314312
cache.loss = cache.loss_new
315313

316314
# Update the trust region radius.
@@ -324,7 +322,7 @@ function trust_region_step!(cache::TrustRegionCache)
324322
cache.make_new_J = false
325323
end
326324

327-
if iszero(cache.fu) || cache.internalnorm(cache.step_size) < cache.abstol
325+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
328326
cache.force_stop = true
329327
end
330328
end
@@ -356,6 +354,16 @@ function dogleg!(cache::TrustRegionCache)
356354
cache.step_size = δsd + τ * N_sd
357355
end
358356

357+
function take_step!(cache::TrustRegionCache{true})
358+
cache.u .= cache.u_tmp
359+
cache.fu .= cache.fu_new
360+
end
361+
362+
function take_step!(cache::TrustRegionCache{false})
363+
cache.u = cache.u_tmp
364+
cache.fu = cache.fu_new
365+
end
366+
359367
function SciMLBase.solve!(cache::TrustRegionCache)
360368
while !cache.force_stop && cache.iter < cache.maxiters &&
361369
cache.shrink_counter < cache.alg.max_shrink_times

test/basictests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ u0 = [1.0, 1.0]
3434

3535
sol = benchmark_immutable(ff, cu0)
3636
@test sol.retcode === ReturnCode.Success
37-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
37+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
3838
sol = benchmark_mutable(ff, u0)
3939
@test sol.retcode === ReturnCode.Success
40-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
40+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
4141
sol = benchmark_scalar(sf, csu0)
4242
@test sol.retcode === ReturnCode.Success
43-
@test sol.u * sol.u - 2 < 1e-9
43+
@test abs(sol.u * sol.u - 2) < 1e-9
4444

4545
# @test (@ballocated benchmark_immutable(ff, cu0)) < 200
4646
# @test (@ballocated benchmark_mutable(ff, cu0)) < 200
@@ -59,7 +59,7 @@ u0 = [1.0, 1.0]
5959

6060
sol = benchmark_inplace(ffiip, u0)
6161
@test sol.retcode === ReturnCode.Success
62-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
62+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
6363

6464
u0 = [1.0, 1.0]
6565
probN = NonlinearProblem{true}(ffiip, u0)
@@ -160,13 +160,13 @@ u0 = [1.0, 1.0]
160160

161161
sol = benchmark_immutable(ff, cu0)
162162
@test sol.retcode === ReturnCode.Success
163-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
163+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
164164
sol = benchmark_mutable(ff, u0)
165165
@test sol.retcode === ReturnCode.Success
166-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
166+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
167167
sol = benchmark_scalar(sf, csu0)
168168
@test sol.retcode === ReturnCode.Success
169-
@test sol.u * sol.u - 2 < 1e-9
169+
@test abs(sol.u * sol.u - 2) < 1e-9
170170

171171
function benchmark_inplace(f, u0)
172172
probN = NonlinearProblem{true}(f, u0)
@@ -181,7 +181,7 @@ u0 = [1.0, 1.0]
181181

182182
sol = benchmark_inplace(ffiip, u0)
183183
@test sol.retcode === ReturnCode.Success
184-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
184+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
185185

186186
u0 = [1.0, 1.0]
187187
probN = NonlinearProblem{true}(ffiip, u0)
@@ -263,7 +263,7 @@ f = (u, p) -> 0.010000000000000002 .+
263263
0.0011552453009332421u .- p
264264
g = function (p)
265265
probN = NonlinearProblem{false}(f, u0, p)
266-
sol = solve(probN, TrustRegion())
266+
sol = solve(probN, TrustRegion(), abstol = 1e-10)
267267
return sol.u
268268
end
269269
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
@@ -295,7 +295,7 @@ for options in list_of_options
295295
expand_factor = options[7],
296296
max_shrink_times = options[8])
297297

298-
probN = NonlinearProblem(f, u0, p)
299-
sol = solve(probN, alg)
300-
@test all(f(u, p) .< 1e-10)
298+
probN = NonlinearProblem{false}(f, u0, p)
299+
sol = solve(probN, alg, abstol = 1e-10)
300+
@test all(abs.(f(u, p)) .< 1e-10)
301301
end

0 commit comments

Comments
 (0)