Skip to content

Commit cf5c17e

Browse files
committed
added tests for fan
1 parent 7971372 commit cf5c17e

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

test/basictests.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ function sf(u, p=nothing)
189189
end
190190

191191
u0 = [1.0, 1.0]
192-
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan]
192+
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
193+
RadiusUpdateSchemes.Fan]
193194

194195
for radius_update_scheme in radius_update_schemes
195196
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
@@ -255,7 +256,6 @@ for p in 1.1:0.1:100.0
255256
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
256257
end
257258

258-
## FAIL BECAUSE JVP CANNOT ACCEPT PARAMETERS IN FUNCTIONS
259259
g = function (p)
260260
probN = NonlinearProblem{false}(f, csu0, p)
261261
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9)
@@ -267,6 +267,17 @@ for p in 1.1:0.1:100.0
267267
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
268268
end
269269

270+
g = function (p)
271+
probN = NonlinearProblem{false}(f, csu0, p)
272+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-9)
273+
return sol.u[end]
274+
end
275+
276+
for p in 1.1:0.1:100.0
277+
@test g(p) sqrt(p)
278+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
279+
end
280+
270281
# Scalar
271282
f, u0 = (u, p) -> u * u - p, 1.0
272283

@@ -309,6 +320,19 @@ for p in 1.1:0.1:100.0
309320
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
310321
end
311322

323+
g = function (p)
324+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
325+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
326+
return sol.u
327+
end
328+
329+
@test ForwardDiff.derivative(g, 3.0) 1 / (2 * sqrt(3.0))
330+
331+
for p in 1.1:0.1:100.0
332+
@test g(p) sqrt(p)
333+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
334+
end
335+
312336
f = (u, p) -> p[1] * u * u - p[2]
313337
t = (p) -> [sqrt(p[2] / p[1])]
314338
p = [0.9, 50.0]

0 commit comments

Comments
 (0)