Skip to content

Commit c698d16

Browse files
Merge pull request #172 from yash2798/ys/tr_jvp
Including jvp for radius update schemes
2 parents c9ffc07 + 50f266d commit c698d16

File tree

3 files changed

+132
-38
lines changed

3 files changed

+132
-38
lines changed

src/jacobian.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,5 @@ function jacobian_autodiff(f, x::AbstractArray, nonlinfun, alg)
144144
jac_prototype = jac_prototype, chunksize = chunk_size),
145145
num_of_chunks)
146146
end
147+
148+

src/trustRegion.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ EnumX.@enumx RadiusUpdateSchemes begin
8585
Hei
8686
Yuan
8787
Bastin
88+
Fan
8889
end
8990

9091
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
@@ -234,7 +235,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
234235
args...;
235236
alias_u0 = false,
236237
maxiters = 1000,
237-
abstol = 1e-6,
238+
abstol = 1e-8,
238239
internalnorm = DEFAULT_NORM,
239240
kwargs...) where {uType, iip}
240241
if alias_u0
@@ -301,7 +302,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
301302
p2 = convert(eltype(u), 1/6) # c5
302303
p3 = convert(eltype(u), 6.0) # c6
303304
p4 = convert(eltype(u), 0.0)
304-
end
305+
end
305306

306307
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
307308
1, false, maxiters, internalnorm,
@@ -402,10 +403,12 @@ function trust_region_step!(cache::TrustRegionCache)
402403
@unpack shrink_threshold, p1, p2, p3, p4 = cache
403404
if rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) < cache.trust_r
404405
cache.shrink_counter += 1
406+
else
407+
cache.shrink_counter = 0
405408
end
406-
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) # parameters to be defined
409+
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size)
407410

408-
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
411+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
409412
cache.force_stop = true
410413
end
411414

@@ -416,19 +419,20 @@ function trust_region_step!(cache::TrustRegionCache)
416419
cache.shrink_counter += 1
417420
elseif r >= cache.expand_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
418421
cache.p1 = cache.p3 * cache.p1
422+
cache.shrink_counter = 0
419423
end
420-
@unpack p1, fu, f, J = cache
421-
#cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
422-
424+
423425
if r > cache.step_threshold
424426
take_step!(cache)
425427
cache.loss = cache.loss_new
426428
cache.make_new_J = true
427429
else
428430
cache.make_new_J = false
429431
end
430-
431-
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
432+
433+
@unpack p1= cache
434+
cache.trust_r = p1 * cache.internalnorm(jvp!(cache))
435+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
432436
cache.force_stop = true
433437
end
434438

@@ -441,7 +445,7 @@ function dogleg!(cache::TrustRegionCache)
441445

442446
# Test if the full step is within the trust region.
443447
if norm(u_tmp) trust_r
444-
cache.step_size = u_tmp
448+
cache.step_size = deepcopy(u_tmp)
445449
return
446450
end
447451

@@ -473,6 +477,23 @@ function take_step!(cache::TrustRegionCache{false})
473477
cache.fu = cache.fu_new
474478
end
475479

480+
function jvp!(cache::TrustRegionCache{false})
481+
@unpack f, u, fu, p = cache
482+
if isa(u, Number)
483+
return value_derivative(x -> f(x, p), u)
484+
end
485+
return auto_jacvec(x -> f(x, p), u, fu)
486+
end
487+
488+
function jvp!(cache::TrustRegionCache{true})
489+
@unpack g, f, u, fu, p = cache
490+
if isa(u, Number)
491+
return value_derivative(x -> f(x, p), u)
492+
end
493+
return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
494+
g
495+
end
496+
476497
function SciMLBase.solve!(cache::TrustRegionCache)
477498
while !cache.force_stop && cache.iter < cache.maxiters &&
478499
cache.shrink_counter < cache.alg.max_shrink_times

test/basictests.jl

Lines changed: 99 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -163,61 +163,69 @@ end
163163

164164
# --- TrustRegion tests ---
165165

166-
function benchmark_immutable(f, u0)
166+
function benchmark_immutable(f, u0, radius_update_scheme)
167167
probN = NonlinearProblem{false}(f, u0)
168-
solver = init(probN, TrustRegion(), abstol = 1e-9)
168+
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
169169
sol = solve!(solver)
170170
end
171171

172-
function benchmark_mutable(f, u0)
172+
function benchmark_mutable(f, u0, radius_update_scheme)
173173
probN = NonlinearProblem{false}(f, u0)
174-
solver = init(probN, TrustRegion(), abstol = 1e-9)
174+
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
175175
sol = solve!(solver)
176176
end
177177

178-
function benchmark_scalar(f, u0)
178+
function benchmark_scalar(f, u0, radius_update_scheme)
179179
probN = NonlinearProblem{false}(f, u0)
180-
sol = (solve(probN, TrustRegion(), abstol = 1e-9))
180+
sol = (solve(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9))
181181
end
182182

183-
function ff(u, p)
183+
function ff(u, p=nothing)
184184
u .* u .- 2
185185
end
186186

187-
function sf(u, p)
187+
function sf(u, p=nothing)
188188
u * u - 2
189189
end
190+
190191
u0 = [1.0, 1.0]
192+
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan]
191193

192-
sol = benchmark_immutable(ff, cu0)
193-
@test sol.retcode === ReturnCode.Success
194-
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
195-
sol = benchmark_mutable(ff, u0)
196-
@test sol.retcode === ReturnCode.Success
197-
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
198-
sol = benchmark_scalar(sf, csu0)
199-
@test sol.retcode === ReturnCode.Success
200-
@test abs(sol.u * sol.u - 2) < 1e-9
194+
for radius_update_scheme in radius_update_schemes
195+
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
196+
@test sol.retcode === ReturnCode.Success
197+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
198+
sol = benchmark_mutable(ff, u0, radius_update_scheme)
199+
@test sol.retcode === ReturnCode.Success
200+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
201+
sol = benchmark_scalar(sf, csu0, radius_update_scheme)
202+
@test sol.retcode === ReturnCode.Success
203+
@test abs(sol.u * sol.u - 2) < 1e-9
204+
end
201205

202-
function benchmark_inplace(f, u0)
206+
207+
function benchmark_inplace(f, u0, radius_update_scheme)
203208
probN = NonlinearProblem{true}(f, u0)
204-
solver = init(probN, TrustRegion(), abstol = 1e-9)
209+
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
205210
sol = solve!(solver)
206211
end
207212

208-
function ffiip(du, u, p)
213+
function ffiip(du, u, p=nothing)
209214
du .= u .* u .- 2
210215
end
211216
u0 = [1.0, 1.0]
212217

213-
sol = benchmark_inplace(ffiip, u0)
214-
@test sol.retcode === ReturnCode.Success
215-
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
218+
for radius_update_scheme in radius_update_schemes
219+
sol = benchmark_inplace(ffiip, u0, radius_update_scheme)
220+
@test sol.retcode === ReturnCode.Success
221+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
222+
end
216223

217-
u0 = [1.0, 1.0]
218-
probN = NonlinearProblem{true}(ffiip, u0)
219-
solver = init(probN, TrustRegion(), abstol = 1e-9)
220-
@test (@ballocated solve!(solver)) < 200
224+
for radius_update_scheme in radius_update_schemes
225+
probN = NonlinearProblem{true}(ffiip, u0)
226+
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
227+
@test (@ballocated solve!(solver)) < 200
228+
end
221229

222230
# AD Tests
223231
using ForwardDiff
@@ -236,6 +244,29 @@ for p in 1.1:0.1:100.0
236244
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
237245
end
238246

247+
g = function (p)
248+
probN = NonlinearProblem{false}(f, csu0, p)
249+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei), abstol = 1e-9)
250+
return sol.u[end]
251+
end
252+
253+
for p in 1.1:0.1:100.0
254+
@test g(p) sqrt(p)
255+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
256+
end
257+
258+
## FAIL BECAUSE JVP CANNOT ACCEPT PARAMETERS IN FUNCTIONS
259+
g = function (p)
260+
probN = NonlinearProblem{false}(f, csu0, p)
261+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9)
262+
return sol.u[end]
263+
end
264+
265+
for p in 1.1:0.1:100.0
266+
@test g(p) sqrt(p)
267+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
268+
end
269+
239270
# Scalar
240271
f, u0 = (u, p) -> u * u - p, 1.0
241272

@@ -252,6 +283,32 @@ for p in 1.1:0.1:100.0
252283
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
253284
end
254285

286+
g = function (p)
287+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
288+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei), abstol = 1e-10)
289+
return sol.u
290+
end
291+
292+
@test ForwardDiff.derivative(g, 3.0) 1 / (2 * sqrt(3.0))
293+
294+
for p in 1.1:0.1:100.0
295+
@test g(p) sqrt(p)
296+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
297+
end
298+
299+
g = function (p)
300+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
301+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-10)
302+
return sol.u
303+
end
304+
305+
@test ForwardDiff.derivative(g, 3.0) 1 / (2 * sqrt(3.0))
306+
307+
for p in 1.1:0.1:100.0
308+
@test g(p) sqrt(p)
309+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
310+
end
311+
255312
f = (u, p) -> p[1] * u * u - p[2]
256313
t = (p) -> [sqrt(p[2] / p[1])]
257314
p = [0.9, 50.0]
@@ -263,6 +320,14 @@ end
263320
@test gnewton(p) [sqrt(p[2] / p[1])]
264321
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
265322

323+
gnewton = function (p)
324+
probN = NonlinearProblem{false}(f, 0.5, p)
325+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei))
326+
return [sol.u]
327+
end
328+
@test gnewton(p) [sqrt(p[2] / p[1])]
329+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
330+
266331
# Iterator interface
267332
f = (u, p) -> u * u - p
268333
g = function (p_range)
@@ -295,12 +360,18 @@ p = range(0.01, 2, length = 200)
295360
@test g(p) sqrt.(p)
296361

297362
# Error Checks
298-
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
363+
f, u0 = (u, p) -> u .* u .- 2, @SVector[1.0, 1.0]
299364
probN = NonlinearProblem(f, u0)
300365

301366
@test solve(probN, TrustRegion()).u[end] sqrt(2.0)
302367
@test solve(probN, TrustRegion(; autodiff = false)).u[end] sqrt(2.0)
303368

369+
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei)).u[end] sqrt(2.0)
370+
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Hei, autodiff = false)).u[end] sqrt(2.0)
371+
372+
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] sqrt(2.0)
373+
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] sqrt(2.0)
374+
304375
for u0 in [1.0, [1, 1.0]]
305376
local f, probN, sol
306377
f = (u, p) -> u .* u .- 2.0

0 commit comments

Comments
 (0)