Skip to content

Commit 05a3e99

Browse files
committed
jvp init value + some tests
1 parent bec197d commit 05a3e99

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/trustRegion.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,15 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
304304
p3 = convert(eltype(u), 6.0) # c6
305305
p4 = convert(eltype(u), 0.0)
306306
if iip
307-
J = ForwardDiff.jacobian(f, fu, u)
307+
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
308308
else
309-
J = ForwardDiff.jacobian(f, u)
309+
if isa(u, Number)
310+
g = ForwardDiff.derivative(x -> f(x, p), u)
311+
else
312+
g = auto_jacvec(x -> f(x, p), u, fu)
313+
end
310314
end
311-
initial_trust_radius = convert(eltype(u), p1 * norm(J * fu))
315+
initial_trust_radius = convert(eltype(u), p1 * norm(g))
312316
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
313317
step_threshold = convert(eltype(u), 0.0001)
314318
shrink_threshold = convert(eltype(u), 0.25)
@@ -527,7 +531,7 @@ function jvp!(cache::TrustRegionCache{true})
527531
if isa(u, Number)
528532
return value_derivative(x -> f(x, p), u)
529533
end
530-
return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
534+
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
531535
g
532536
end
533537

test/basictests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,22 @@ end
352352
@test gnewton(p) [sqrt(p[2] / p[1])]
353353
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
354354

355+
gnewton = function (p)
356+
probN = NonlinearProblem{false}(f, 0.5, p)
357+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan))
358+
return [sol.u]
359+
end
360+
@test gnewton(p) [sqrt(p[2] / p[1])]
361+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
362+
363+
gnewton = function (p)
364+
probN = NonlinearProblem{false}(f, 0.5, p)
365+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan))
366+
return [sol.u]
367+
end
368+
@test gnewton(p) [sqrt(p[2] / p[1])]
369+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
370+
355371
# Iterator interface
356372
f = (u, p) -> u * u - p
357373
g = function (p_range)
@@ -396,6 +412,9 @@ probN = NonlinearProblem(f, u0)
396412
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] sqrt(2.0)
397413
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] sqrt(2.0)
398414

415+
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)).u[end] sqrt(2.0)
416+
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] sqrt(2.0)
417+
399418
for u0 in [1.0, [1, 1.0]]
400419
local f, probN, sol
401420
f = (u, p) -> u .* u .- 2.0
@@ -428,6 +447,16 @@ u = g(p)
428447
f(u, p)
429448
@test all(abs.(f(u, p)) .< 1e-10)
430449

450+
g = function (p)
451+
probN = NonlinearProblem{false}(f, u0, p)
452+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
453+
return sol.u
454+
end
455+
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
456+
u = g(p)
457+
f(u, p)
458+
@test all(abs.(f(u, p)) .< 1e-10)
459+
431460
# Test kwars in `TrustRegion`
432461
max_trust_radius = [10.0, 100.0, 1000.0]
433462
initial_trust_radius = [10.0, 1.0, 0.1]

0 commit comments

Comments
 (0)