Skip to content

Commit 679fdab

Browse files
committed
Bug fix for TrustRegion when iip=true.
1 parent 9049a3c commit 679fdab

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

src/trustRegion.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,8 @@ function trust_region_step!(cache::TrustRegionCache)
309309
if r > alg.step_threshold
310310

311311
# Take the step.
312-
cache.u = u_tmp
313-
cache.fu = fu_new
312+
cache.u = copy(u_tmp)
313+
cache.fu = copy(fu_new)
314314
cache.loss = cache.loss_new
315315

316316
# Update the trust region radius.
@@ -324,7 +324,7 @@ function trust_region_step!(cache::TrustRegionCache)
324324
cache.make_new_J = false
325325
end
326326

327-
if iszero(cache.fu) || cache.internalnorm(cache.step_size) < cache.abstol
327+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
328328
cache.force_stop = true
329329
end
330330
end

test/basictests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ u0 = [1.0, 1.0]
3434

3535
sol = benchmark_immutable(ff, cu0)
3636
@test sol.retcode === ReturnCode.Success
37-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
37+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
3838
sol = benchmark_mutable(ff, u0)
3939
@test sol.retcode === ReturnCode.Success
40-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
40+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
4141
sol = benchmark_scalar(sf, csu0)
4242
@test sol.retcode === ReturnCode.Success
43-
@test sol.u * sol.u - 2 < 1e-9
43+
@test abs(sol.u * sol.u - 2) < 1e-9
4444

4545
# @test (@ballocated benchmark_immutable(ff, cu0)) < 200
4646
# @test (@ballocated benchmark_mutable(ff, cu0)) < 200
@@ -59,7 +59,7 @@ u0 = [1.0, 1.0]
5959

6060
sol = benchmark_inplace(ffiip, u0)
6161
@test sol.retcode === ReturnCode.Success
62-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
62+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
6363

6464
u0 = [1.0, 1.0]
6565
probN = NonlinearProblem{true}(ffiip, u0)
@@ -160,13 +160,13 @@ u0 = [1.0, 1.0]
160160

161161
sol = benchmark_immutable(ff, cu0)
162162
@test sol.retcode === ReturnCode.Success
163-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
163+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
164164
sol = benchmark_mutable(ff, u0)
165165
@test sol.retcode === ReturnCode.Success
166-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
166+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
167167
sol = benchmark_scalar(sf, csu0)
168168
@test sol.retcode === ReturnCode.Success
169-
@test sol.u * sol.u - 2 < 1e-9
169+
@test abs(sol.u * sol.u - 2) < 1e-9
170170

171171
function benchmark_inplace(f, u0)
172172
probN = NonlinearProblem{true}(f, u0)
@@ -181,7 +181,7 @@ u0 = [1.0, 1.0]
181181

182182
sol = benchmark_inplace(ffiip, u0)
183183
@test sol.retcode === ReturnCode.Success
184-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
184+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
185185

186186
u0 = [1.0, 1.0]
187187
probN = NonlinearProblem{true}(ffiip, u0)
@@ -263,7 +263,7 @@ f = (u, p) -> 0.010000000000000002 .+
263263
0.0011552453009332421u .- p
264264
g = function (p)
265265
probN = NonlinearProblem{false}(f, u0, p)
266-
sol = solve(probN, TrustRegion())
266+
sol = solve(probN, TrustRegion(), abstol = 1e-10)
267267
return sol.u
268268
end
269269
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
@@ -295,7 +295,7 @@ for options in list_of_options
295295
expand_factor = options[7],
296296
max_shrink_times = options[8])
297297

298-
probN = NonlinearProblem(f, u0, p)
299-
sol = solve(probN, alg)
300-
@test all(f(u, p) .< 1e-10)
298+
probN = NonlinearProblem{false}(f, u0, p)
299+
sol = solve(probN, alg, abstol = 1e-10)
300+
@test all(abs.(f(u, p)) .< 1e-10)
301301
end

0 commit comments

Comments
 (0)