|
| 1 | +using NonlinearSolve |
| 2 | +using StaticArrays |
| 3 | +using BenchmarkTools |
| 4 | +using Test |
| 5 | + |
| 6 | +function benchmark_immutable(f, u0) |
| 7 | + probN = NonlinearProblem{false}(f, u0) |
| 8 | + solver = init(probN, NewtonRaphson(), tol = 1e-9) |
| 9 | + sol = solve!(solver) |
| 10 | +end |
| 11 | + |
| 12 | +function benchmark_mutable(f, u0) |
| 13 | + probN = NonlinearProblem{false}(f, u0) |
| 14 | + solver = init(probN, NewtonRaphson(), tol = 1e-9) |
| 15 | + sol = (reinit!(solver, probN); solve!(solver)) |
| 16 | +end |
| 17 | + |
| 18 | +function benchmark_scalar(f, u0) |
| 19 | + probN = NonlinearProblem{false}(f, u0) |
| 20 | + sol = (solve(probN, NewtonRaphson())) |
| 21 | +end |
| 22 | + |
| 23 | +function ff(u,p) |
| 24 | + u .* u .- 2 |
| 25 | +end |
| 26 | +const cu0 = @SVector[1.0, 1.0] |
| 27 | +function sf(u,p) |
| 28 | + u * u - 2 |
| 29 | +end |
| 30 | +const csu0 = 1.0 |
| 31 | + |
| 32 | +sol = benchmark_immutable(ff, cu0) |
| 33 | +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) |
| 34 | +@test all(sol.u .* sol.u .- 2 .< 1e-9) |
| 35 | +sol = benchmark_mutable(ff, cu0) |
| 36 | +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) |
| 37 | +@test all(sol.u .* sol.u .- 2 .< 1e-9) |
| 38 | +sol = benchmark_scalar(sf, csu0) |
| 39 | +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) |
| 40 | +@test sol.u * sol.u - 2 < 1e-9 |
| 41 | + |
| 42 | +@test (@ballocated benchmark_immutable(ff, cu0)) == 0 |
| 43 | +@test (@ballocated benchmark_mutable(ff, cu0)) < 200 |
| 44 | +@test (@ballocated benchmark_scalar(sf, csu0)) == 0 |
| 45 | + |
| 46 | +# AD Tests |
| 47 | +using ForwardDiff |
| 48 | + |
| 49 | +# Immutable |
| 50 | +f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0] |
| 51 | + |
| 52 | +g = function (p) |
| 53 | + probN = NonlinearProblem{false}(f, u0, p) |
| 54 | + sol = solve(probN, NewtonRaphson(), tol = 1e-9) |
| 55 | + return sol.u[end] |
| 56 | +end |
| 57 | + |
| 58 | +for p in 1.0:0.1:100.0 |
| 59 | + @test g(p) ≈ sqrt(p) |
| 60 | + @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) |
| 61 | +end |
| 62 | + |
| 63 | +# Scalar |
| 64 | +f, u0 = (u, p) -> u * u - p, 1.0 |
| 65 | + |
| 66 | +# NewtonRaphson |
| 67 | +g = function (p) |
| 68 | + probN = NonlinearProblem{false}(f, oftype(p, u0), p) |
| 69 | + sol = solve(probN, NewtonRaphson()) |
| 70 | + return sol.u |
| 71 | +end |
| 72 | + |
| 73 | +@test ForwardDiff.derivative(g, 1.0) ≈ 0.5 |
| 74 | + |
| 75 | +for p in 1.1:0.1:100.0 |
| 76 | + @test g(p) ≈ sqrt(p) |
| 77 | + @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) |
| 78 | +end |
| 79 | + |
| 80 | +u0 = (1.0, 20.0) |
| 81 | +# Falsi |
| 82 | +g = function (p) |
| 83 | + probN = NonlinearProblem{false}(f, typeof(p).(u0), p) |
| 84 | + sol = solve(probN, Falsi()) |
| 85 | + return sol.left |
| 86 | +end |
| 87 | + |
| 88 | +for p in 1.1:0.1:100.0 |
| 89 | + @test g(p) ≈ sqrt(p) |
| 90 | + @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) |
| 91 | +end |
| 92 | + |
| 93 | +f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) |
| 94 | +t = (p) -> [sqrt(p[2] / p[1])] |
| 95 | +p = [0.9, 50.0] |
| 96 | +for alg in [Bisection(), Falsi()] |
| 97 | + global g, p |
| 98 | + g = function (p) |
| 99 | + probN = NonlinearProblem{false}(f, u0, p) |
| 100 | + sol = solve(probN, Bisection()) |
| 101 | + return [sol.left] |
| 102 | + end |
| 103 | + |
| 104 | + @test g(p) ≈ [sqrt(p[2] / p[1])] |
| 105 | + @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) |
| 106 | +end |
| 107 | + |
| 108 | +gnewton = function (p) |
| 109 | + probN = NonlinearProblem{false}(f, 0.5, p) |
| 110 | + sol = solve(probN, NewtonRaphson()) |
| 111 | + return [sol.u] |
| 112 | +end |
| 113 | +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] |
| 114 | +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) |
| 115 | + |
| 116 | +# Error Checks |
| 117 | + |
| 118 | +f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] |
| 119 | +probN = NonlinearProblem(f, u0) |
| 120 | + |
| 121 | +@test solve(probN, NewtonRaphson()).u[end] ≈ sqrt(2.0) |
| 122 | +@test solve(probN, NewtonRaphson(); immutable = false).u[end] ≈ sqrt(2.0) |
| 123 | +@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0) |
| 124 | +@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0) |
| 125 | + |
| 126 | +for u0 in [1.0, [1, 1.0]] |
| 127 | + local f, probN, sol |
| 128 | + f = (u, p) -> u .* u .- 2.0 |
| 129 | + probN = NonlinearProblem(f, u0) |
| 130 | + sol = sqrt(2) * u0 |
| 131 | + |
| 132 | + @test solve(probN, NewtonRaphson()).u ≈ sol |
| 133 | + @test solve(probN, NewtonRaphson()).u ≈ sol |
| 134 | + @test solve(probN, NewtonRaphson(;autodiff=false)).u ≈ sol |
| 135 | +end |
| 136 | + |
| 137 | +# Bisection Tests |
| 138 | +f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0) |
| 139 | +probB = NonlinearProblem(f, u0) |
| 140 | + |
| 141 | +# Falsi |
| 142 | +solver = init(probB, Falsi()) |
| 143 | +sol = solve!(solver) |
| 144 | +@test sol.left ≈ sqrt(2.0) |
| 145 | + |
| 146 | +# this should call the fast scalar overload |
| 147 | +@test solve(probB, Bisection()).left ≈ sqrt(2.0) |
| 148 | + |
| 149 | +# these should call the iterator version |
| 150 | +solver = init(probB, Bisection()) |
| 151 | +@test solver isa NonlinearSolve.BracketingImmutableSolver |
| 152 | +@test solve!(solver).left ≈ sqrt(2.0) |
| 153 | + |
| 154 | +# Garuntee Tests for Bisection |
| 155 | +f = function (u, p) |
| 156 | + if u < 2.0 |
| 157 | + return u - 2.0 |
| 158 | + elseif u > 3.0 |
| 159 | + return u - 3.0 |
| 160 | + else |
| 161 | + return 0.0 |
| 162 | + end |
| 163 | +end |
| 164 | +probB = NonlinearProblem(f, (0.0, 4.0)) |
| 165 | + |
| 166 | +solver = init(probB, Bisection(;exact_left = true)) |
| 167 | +sol = solve!(solver) |
| 168 | +@test f(sol.left, nothing) < 0.0 |
| 169 | +@test f(nextfloat(sol.left), nothing) >= 0.0 |
| 170 | + |
| 171 | +solver = init(probB, Bisection(;exact_right = true)) |
| 172 | +sol = solve!(solver) |
| 173 | +@test f(sol.right, nothing) > 0.0 |
| 174 | +@test f(prevfloat(sol.right), nothing) <= 0.0 |
| 175 | + |
| 176 | +solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false) |
| 177 | +sol = solve!(solver) |
| 178 | +@test f(sol.left, nothing) < 0.0 |
| 179 | +@test f(nextfloat(sol.left), nothing) >= 0.0 |
| 180 | +@test f(sol.right, nothing) > 0.0 |
| 181 | +@test f(prevfloat(sol.right), nothing) <= 0.0 |
0 commit comments