|
| 1 | +@testsnippet RootfindTestSnippet begin |
| 2 | + using StaticArrays, Random, LinearAlgebra, ForwardDiff, NonlinearSolveBase, SciMLBase |
| 3 | + using PolyesterForwardDiff, Enzyme, ReverseDiff |
1 | 4 |
|
| 5 | + quadratic_f(u, p) = u .* u .- p |
| 6 | + quadratic_f!(du, u, p) = (du .= u .* u .- p) |
| 7 | + |
| 8 | + function newton_fails(u, p) |
| 9 | + return 0.010000000000000002 .+ |
| 10 | + 10.000000000000002 ./ (1 .+ |
| 11 | + (0.21640425613334457 .+ |
| 12 | + 216.40425613334457 ./ (1 .+ |
| 13 | + (0.21640425613334457 .+ |
| 14 | + 216.40425613334457 ./ (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ |
| 15 | + 2.0) .- 0.0011552453009332421u .- p |
| 16 | + end |
| 17 | + |
| 18 | + const TERMINATION_CONDITIONS = [ |
| 19 | + NormTerminationMode(Base.Fix1(maximum, abs)), |
| 20 | + RelTerminationMode(), |
| 21 | + RelNormTerminationMode(Base.Fix1(maximum, abs)), |
| 22 | + RelNormSafeTerminationMode(Base.Fix1(maximum, abs)), |
| 23 | + RelNormSafeBestTerminationMode(Base.Fix1(maximum, abs)), |
| 24 | + AbsTerminationMode(), |
| 25 | + AbsNormTerminationMode(Base.Fix1(maximum, abs)), |
| 26 | + AbsNormSafeTerminationMode(Base.Fix1(maximum, abs)), |
| 27 | + AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs)) |
| 28 | + ] |
| 29 | + |
| 30 | + function run_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F} |
| 31 | + return solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9) |
| 32 | + end |
| 33 | + function run_nlsolve_iip(f!::F, u0, p = 2.0; solver) where {F} |
| 34 | + return solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9) |
| 35 | + end |
| 36 | +end |
| 37 | + |
| 38 | +@testitem "First Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin |
| 39 | + @testset for alg in ( |
| 40 | + SimpleNewtonRaphson, |
| 41 | + SimpleTrustRegion, |
| 42 | + (; kwargs...) -> SimpleTrustRegion(; kwargs..., nlsolve_update_rule = Val(true)) |
| 43 | + ) |
| 44 | + @testset for autodiff in ( |
| 45 | + AutoForwardDiff(), |
| 46 | + AutoFiniteDiff(), |
| 47 | + AutoReverseDiff(), |
| 48 | + AutoEnzyme(), |
| 49 | + nothing |
| 50 | + ) |
| 51 | + @testset "[OOP] u0: $(typeof(u0))" for u0 in ( |
| 52 | + [1.0, 1.0], @SVector[1.0, 1.0], 1.0) |
| 53 | + sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff)) |
| 54 | + @test SciMLBase.successful_retcode(sol) |
| 55 | + @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 |
| 56 | + end |
| 57 | + |
| 58 | + @testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],) |
| 59 | + sol = run_nlsolve_iip(quadratic_f!, u0; solver = alg(; autodiff)) |
| 60 | + @test SciMLBase.successful_retcode(sol) |
| 61 | + @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 |
| 62 | + end |
| 63 | + |
| 64 | + @testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, |
| 65 | + u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) |
| 66 | + |
| 67 | + probN = NonlinearProblem(quadratic_f, u0, 2.0) |
| 68 | + @test all(solve( |
| 69 | + probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈ |
| 70 | + sqrt(2.0)) |
| 71 | + end |
| 72 | + end |
| 73 | + end |
| 74 | +end |
| 75 | + |
| 76 | +@testitem "Second Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin |
| 77 | + @testset for alg in ( |
| 78 | + SimpleHalley, |
| 79 | + ) |
| 80 | + @testset for autodiff in ( |
| 81 | + AutoForwardDiff(), |
| 82 | + AutoFiniteDiff(), |
| 83 | + AutoReverseDiff(), |
| 84 | + nothing |
| 85 | + ) |
| 86 | + @testset "[OOP] u0: $(typeof(u0))" for u0 in ( |
| 87 | + [1.0, 1.0], @SVector[1.0, 1.0], 1.0) |
| 88 | + sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff)) |
| 89 | + @test SciMLBase.successful_retcode(sol) |
| 90 | + @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 |
| 91 | + end |
| 92 | + end |
| 93 | + |
| 94 | + @testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, |
| 95 | + u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) |
| 96 | + |
| 97 | + probN = NonlinearProblem(quadratic_f, u0, 2.0) |
| 98 | + @test all(solve( |
| 99 | + probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈ |
| 100 | + sqrt(2.0)) |
| 101 | + end |
| 102 | + end |
| 103 | +end |
| 104 | + |
| 105 | +@testitem "Derivative Free Metods" setup=[RootfindTestSnippet] tags=[:core] begin |
| 106 | + @testset "$(nameof(typeof(alg)))" for alg in ( |
| 107 | + SimpleBroyden(), |
| 108 | + SimpleKlement(), |
| 109 | + SimpleDFSane(), |
| 110 | + SimpleLimitedMemoryBroyden(), |
| 111 | + SimpleBroyden(; linesearch = Val(true)), |
| 112 | + SimpleLimitedMemoryBroyden(; linesearch = Val(true)) |
| 113 | + ) |
| 114 | + @testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) |
| 115 | + sol = run_nlsolve_oop(quadratic_f, u0; solver = alg) |
| 116 | + @test SciMLBase.successful_retcode(sol) |
| 117 | + @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 |
| 118 | + end |
| 119 | + |
| 120 | + @testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],) |
| 121 | + sol = run_nlsolve_iip(quadratic_f!, u0; solver = alg) |
| 122 | + @test SciMLBase.successful_retcode(sol) |
| 123 | + @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 |
| 124 | + end |
| 125 | + |
| 126 | + @testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, |
| 127 | + u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) |
| 128 | + |
| 129 | + probN = NonlinearProblem(quadratic_f, u0, 2.0) |
| 130 | + @test all(solve(probN, alg; termination_condition).u .≈ sqrt(2.0)) |
| 131 | + end |
| 132 | + end |
| 133 | +end |
| 134 | + |
| 135 | +@testitem "Newton Fails" setup=[RootfindTestSnippet] tags=[:core] begin |
| 136 | + u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0] |
| 137 | + p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] |
| 138 | + |
| 139 | + @testset "$(nameof(typeof(alg)))" for alg in ( |
| 140 | + SimpleDFSane(), |
| 141 | + SimpleTrustRegion(), |
| 142 | + SimpleHalley(), |
| 143 | + SimpleTrustRegion(; nlsolve_update_rule = Val(true)) |
| 144 | + ) |
| 145 | + sol = run_nlsolve_oop(newton_fails, u0, p; solver = alg) |
| 146 | + @test SciMLBase.successful_retcode(sol) |
| 147 | + @test maximum(abs, newton_fails(sol.u, p)) < 1e-9 |
| 148 | + end |
| 149 | +end |
| 150 | + |
| 151 | +@testitem "Kwargs Propagation" setup=[RootfindTestSnippet] tags=[:core] begin |
| 152 | + prob = NonlinearProblem(quadratic_f, ones(4), 2.0; maxiters = 2) |
| 153 | + sol = solve(prob, SimpleNewtonRaphson()) |
| 154 | + @test sol.retcode === ReturnCode.MaxIters |
| 155 | +end |
0 commit comments