Skip to content

Commit eb3a6ff

Browse files
committed
Finalize tests
1 parent 5963ec9 commit eb3a6ff

File tree

5 files changed

+138
-238
lines changed

5 files changed

+138
-238
lines changed

src/ad.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
3030
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
3131
sol.retcode)
3232
end
33+
3334
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
3435
<:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
3536
kwargs...) where {iip, T, V, P}

test/basictests.jl

Lines changed: 122 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,21 @@ using BenchmarkTools, LinearSolve, NonlinearSolve, StaticArrays, Random, LinearA
33

44
_nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
55

6+
quadratic_f(u, p) = u .* u .- p
7+
quadratic_f!(du, u, p) = (du .= u .* u .- p)
8+
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
9+
10+
function newton_fails(u, p)
11+
return 0.010000000000000002 .+
12+
10.000000000000002 ./ (1 .+
13+
(0.21640425613334457 .+
14+
216.40425613334457 ./ (1 .+
15+
(0.21640425613334457 .+
16+
216.40425613334457 ./
17+
(1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ 2.0) .-
18+
0.0011552453009332421u .- p
19+
end
20+
621
# --- NewtonRaphson tests ---
722

823
@testset "NewtonRaphson" begin
@@ -16,9 +31,6 @@ _nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
1631
return solve(prob, NewtonRaphson(; linsolve, precs), abstol = 1e-9)
1732
end
1833

19-
quadratic_f(u, p) = u .* u .- p
20-
quadratic_f!(du, u, p) = (du .= u .* u .- p)
21-
2234
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
2335
sol = benchmark_nlsolve_oop(quadratic_f, u0)
2436
@test SciMLBase.successful_retcode(sol)
@@ -40,7 +52,7 @@ _nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
4052
@test SciMLBase.successful_retcode(sol)
4153
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
4254

43-
cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
55+
cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
4456
NewtonRaphson(; linsolve, precs = prec), abstol = 1e-9)
4557
@test (@ballocated solve!($cache)) 64
4658
end
@@ -67,7 +79,6 @@ _nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
6779
1 / (2 * sqrt(p))
6880
end
6981

70-
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
7182
t = (p) -> [sqrt(p[2] / p[1])]
7283
p = [0.9, 50.0]
7384
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@@ -113,9 +124,6 @@ end
113124
return solve(prob, TrustRegion(; radius_update_scheme); abstol = 1e-9, kwargs...)
114125
end
115126

116-
quadratic_f(u, p) = u .* u .- p
117-
quadratic_f!(du, u, p) = (du .= u .* u .- p)
118-
119127
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei,
120128
RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
121129

@@ -169,7 +177,6 @@ end
169177
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
170178
end
171179

172-
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
173180
t = (p) -> [sqrt(p[2] / p[1])]
174181
p = [0.9, 50.0]
175182
@testset "[OOP] [Jacobian] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
@@ -209,17 +216,6 @@ end
209216
end
210217

211218
# Test that `TrustRegion` passes a test that `NewtonRaphson` fails on.
212-
function newton_fails(u, p)
213-
return 0.010000000000000002 .+
214-
10.000000000000002 ./ (1 .+
215-
(0.21640425613334457 .+
216-
216.40425613334457 ./ (1 .+
217-
(0.21640425613334457 .+
218-
216.40425613334457 ./
219-
(1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ 2.0) .-
220-
0.0011552453009332421u .- p
221-
end
222-
223219
@testset "Newton Raphson Fails: radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in [
224220
RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
225221
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
@@ -251,9 +247,9 @@ end
251247
shrink_factor = options[6], expand_factor = options[7],
252248
max_shrink_times = options[8])
253249

254-
probN = NonlinearProblem{false}(f, u0, p)
250+
probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
255251
sol = solve(probN, alg, abstol = 1e-10)
256-
@test all(abs.(f(u, p)) .< 1e-10)
252+
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
257253
end
258254
end
259255

@@ -275,170 +271,107 @@ end
275271

276272
# --- LevenbergMarquardt tests ---
277273

278-
@testset "LevenbergMarquardt" begin end
279-
280-
# function benchmark_immutable(f, u0)
281-
# probN = NonlinearProblem{false}(f, u0)
282-
# solver = init(probN, LevenbergMarquardt(), abstol = 1e-9)
283-
# sol = solve!(solver)
284-
# end
285-
286-
# function benchmark_mutable(f, u0)
287-
# probN = NonlinearProblem{false}(f, u0)
288-
# solver = init(probN, LevenbergMarquardt(), abstol = 1e-9)
289-
# sol = solve!(solver)
290-
# end
291-
292-
# function benchmark_scalar(f, u0)
293-
# probN = NonlinearProblem{false}(f, u0)
294-
# sol = (solve(probN, LevenbergMarquardt(), abstol = 1e-9))
295-
# end
296-
297-
# function ff(u, p)
298-
# u .* u .- 2
299-
# end
300-
301-
# function sf(u, p)
302-
# u * u - 2
303-
# end
304-
# u0 = [1.0, 1.0]
305-
306-
# sol = benchmark_immutable(ff, cu0)
307-
# @test SciMLBase.successful_retcode(sol)
308-
# @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
309-
# sol = benchmark_mutable(ff, u0)
310-
# @test SciMLBase.successful_retcode(sol)
311-
# @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
312-
# sol = benchmark_scalar(sf, csu0)
313-
# @test SciMLBase.successful_retcode(sol)
314-
# @test abs(sol.u * sol.u - 2) < 1e-9
315-
316-
# function benchmark_inplace(f, u0)
317-
# probN = NonlinearProblem{true}(f, u0)
318-
# solver = init(probN, LevenbergMarquardt(), abstol = 1e-9)
319-
# sol = solve!(solver)
320-
# end
321-
322-
# function ffiip(du, u, p)
323-
# du .= u .* u .- 2
324-
# end
325-
# u0 = [1.0, 1.0]
326-
327-
# sol = benchmark_inplace(ffiip, u0)
328-
# @test SciMLBase.successful_retcode(sol)
329-
# @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
330-
331-
# u0 = [1.0, 1.0]
332-
# probN = NonlinearProblem{true}(ffiip, u0)
333-
# solver = init(probN, LevenbergMarquardt(), abstol = 1e-9)
334-
# @test (@ballocated solve!(solver)) < 120
335-
336-
# # AD Tests
337-
# using ForwardDiff
338-
339-
# # Immutable
340-
# f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
341-
342-
# g = function (p)
343-
# probN = NonlinearProblem{false}(f, csu0, p)
344-
# sol = solve(probN, LevenbergMarquardt(), abstol = 1e-9)
345-
# return sol.u[end]
346-
# end
347-
348-
# for p in 1.1:0.1:100.0
349-
# @test g(p) ≈ sqrt(p)
350-
# @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
351-
# end
352-
353-
# # Scalar
354-
# f, u0 = (u, p) -> u * u - p, 1.0
355-
356-
# g = function (p)
357-
# probN = NonlinearProblem{false}(f, oftype(p, u0), p)
358-
# sol = solve(probN, LevenbergMarquardt(), abstol = 1e-10)
359-
# return sol.u
360-
# end
361-
362-
# @test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))
363-
364-
# for p in 1.1:0.1:100.0
365-
# @test g(p) ≈ sqrt(p)
366-
# @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
367-
# end
368-
369-
# f = (u, p) -> p[1] * u * u - p[2]
370-
# t = (p) -> [sqrt(p[2] / p[1])]
371-
# p = [0.9, 50.0]
372-
# gnewton = function (p)
373-
# probN = NonlinearProblem{false}(f, 0.5, p)
374-
# sol = solve(probN, LevenbergMarquardt())
375-
# return [sol.u]
376-
# end
377-
# @test gnewton(p) ≈ [sqrt(p[2] / p[1])]
378-
# @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)
379-
380-
# # Error Checks
381-
# f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
382-
# probN = NonlinearProblem(f, u0)
383-
384-
# @test solve(probN, LevenbergMarquardt()).u[end] ≈ sqrt(2.0)
385-
# @test solve(probN, LevenbergMarquardt(; autodiff = false)).u[end] ≈ sqrt(2.0)
386-
387-
# for u0 in [1.0, [1, 1.0]]
388-
# local f, probN, sol
389-
# f = (u, p) -> u .* u .- 2.0
390-
# probN = NonlinearProblem(f, u0)
391-
# sol = sqrt(2) * u0
392-
393-
# @test solve(probN, LevenbergMarquardt()).u ≈ sol
394-
# @test solve(probN, LevenbergMarquardt()).u ≈ sol
395-
# @test solve(probN, LevenbergMarquardt(; autodiff = false)).u ≈ sol
396-
# end
397-
398-
# # Test that `LevenbergMarquardt` passes a test that `NewtonRaphson` fails on.
399-
# u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
400-
# global g, f
401-
# f = (u, p) -> 0.010000000000000002 .+
402-
# 10.000000000000002 ./ (1 .+
403-
# (0.21640425613334457 .+
404-
# 216.40425613334457 ./ (1 .+
405-
# (0.21640425613334457 .+
406-
# 216.40425613334457 ./
407-
# (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ 2.0) .-
408-
# 0.0011552453009332421u .- p
409-
# g = function (p)
410-
# probN = NonlinearProblem{false}(f, u0, p)
411-
# sol = solve(probN, LevenbergMarquardt(), abstol = 1e-10)
412-
# return sol.u
413-
# end
414-
# p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
415-
# u = g(p)
416-
# f(u, p)
417-
# @test all(abs.(f(u, p)) .< 1e-10)
418-
419-
# # # Test kwars in `LevenbergMarquardt`
420-
# damping_initial = [0.5, 2.0, 5.0]
421-
# damping_increase_factor = [1.5, 3.0, 10.0]
422-
# damping_decrease_factor = [2, 5, 10]
423-
# finite_diff_step_geodesic = [0.02, 0.2, 0.3]
424-
# α_geodesic = [0.6, 0.8, 0.9]
425-
# b_uphill = [0, 1, 2]
426-
# min_damping_D = [1e-12, 1e-9, 1e-4]
427-
428-
# list_of_options = zip(damping_initial, damping_increase_factor, damping_decrease_factor,
429-
# finite_diff_step_geodesic, α_geodesic, b_uphill,
430-
# min_damping_D)
431-
# for options in list_of_options
432-
# local probN, sol, alg
433-
# alg = LevenbergMarquardt(damping_initial = options[1],
434-
# damping_increase_factor = options[2],
435-
# damping_decrease_factor = options[3],
436-
# finite_diff_step_geodesic = options[4],
437-
# α_geodesic = options[5],
438-
# b_uphill = options[6],
439-
# min_damping_D = options[7])
440-
441-
# probN = NonlinearProblem{false}(f, u0, p)
442-
# sol = solve(probN, alg, abstol = 1e-10)
443-
# @test all(abs.(f(u, p)) .< 1e-10)
444-
# end
274+
@testset "LevenbergMarquardt" begin
275+
function benchmark_nlsolve_oop(f, u0, p = 2.0)
276+
prob = NonlinearProblem{false}(f, u0, p)
277+
return solve(prob, LevenbergMarquardt(), abstol = 1e-9)
278+
end
279+
280+
function benchmark_nlsolve_iip(f, u0, p = 2.0)
281+
prob = NonlinearProblem{true}(f, u0, p)
282+
return solve(prob, LevenbergMarquardt(), abstol = 1e-9)
283+
end
284+
285+
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
286+
sol = benchmark_nlsolve_oop(quadratic_f, u0)
287+
@test SciMLBase.successful_retcode(sol)
288+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
289+
290+
cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), LevenbergMarquardt(),
291+
abstol = 1e-9)
292+
@test (@ballocated solve!($cache)) < 200
293+
end
294+
295+
@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
296+
sol = benchmark_nlsolve_iip(quadratic_f!, u0)
297+
@test SciMLBase.successful_retcode(sol)
298+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
299+
300+
cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), LevenbergMarquardt(),
301+
abstol = 1e-9)
302+
@test (@ballocated solve!($cache)) 64
303+
end
304+
305+
# FIXME: Even the previous tests were broken, but due to a typo in the tests they
306+
# accidentally passed
307+
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
308+
@test begin
309+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
310+
res_true = sqrt(p)
311+
all(res.u .≈ res_true)
312+
end
313+
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
314+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
315+
end
316+
317+
@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
318+
@test begin
319+
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
320+
res_true = sqrt(p)
321+
res.u res_true
322+
end
323+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)
324+
1 / (2 * sqrt(p))
325+
end
326+
327+
t = (p) -> [sqrt(p[2] / p[1])]
328+
p = [0.9, 50.0]
329+
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
330+
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p)
331+
ForwardDiff.jacobian(t, p)
332+
333+
probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
334+
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
335+
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
336+
AutoSparseZygote(),
337+
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
338+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
339+
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
340+
end
341+
342+
# Test that `LevenbergMarquardt` passes a test that `NewtonRaphson` fails on.
343+
@testset "Newton Raphson Fails" begin
344+
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
345+
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
346+
sol = benchmark_nlsolve_oop(newton_fails, u0, p)
347+
@test SciMLBase.successful_retcode(sol)
348+
@test all(abs.(newton_fails(sol.u, p)) .< 1e-9)
349+
end
350+
351+
# Test kwargs in `LevenbergMarquardt`
352+
@testset "Keyword Arguments" begin
353+
damping_initial = [0.5, 2.0, 5.0]
354+
damping_increase_factor = [1.5, 3.0, 10.0]
355+
damping_decrease_factor = Float64[2, 5, 10]
356+
finite_diff_step_geodesic = [0.02, 0.2, 0.3]
357+
α_geodesic = [0.6, 0.8, 0.9]
358+
b_uphill = Float64[0, 1, 2]
359+
min_damping_D = [1e-12, 1e-9, 1e-4]
360+
361+
list_of_options = zip(damping_initial, damping_increase_factor,
362+
damping_decrease_factor, finite_diff_step_geodesic, α_geodesic, b_uphill,
363+
min_damping_D)
364+
for options in list_of_options
365+
local probN, sol, alg
366+
alg = LevenbergMarquardt(damping_initial = options[1],
367+
damping_increase_factor = options[2],
368+
damping_decrease_factor = options[3],
369+
finite_diff_step_geodesic = options[4], α_geodesic = options[5],
370+
b_uphill = options[6], min_damping_D = options[7])
371+
372+
probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
373+
sol = solve(probN, alg, abstol = 1e-10)
374+
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
375+
end
376+
end
377+
end

0 commit comments

Comments
 (0)