Skip to content

Commit a0c7473

Browse files
authored
Merge pull request #1 from JuliaComputing/myb/benchmark
Modify the benchmark and tests
2 parents 5b72760 + 63418b8 commit a0c7473

File tree

5 files changed

+47
-47
lines changed

5 files changed

+47
-47
lines changed

src/NonlinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ module NonlinearSolve
2727

2828
# DiffEq styled algorithms
2929
export Bisection, Falsi, NewtonRaphson
30-
export ScalarBisection, ScalarNewton
3130

3231
export reinit!
3332
end # module

src/scalar.jl

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
"""
2-
ScalarNewton
3-
4-
Fast Newton Raphson for scalar problems.
5-
"""
6-
struct ScalarNewton <: AbstractNonlinearSolveAlgorithm end
7-
8-
function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarNewton, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) where {uType}
1+
function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
92
f = Base.Fix2(prob.f, prob.p)
103
x = float(prob.u0)
114
T = typeof(x)
@@ -15,26 +8,18 @@ function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarNewton,
158
xo = oftype(x, Inf)
169
for i in 1:maxiters
1710
fx, dfx = value_derivative(f, x)
18-
iszero(fx) && return x
11+
iszero(fx) && return NewtonSolution(x, :Default)
1912
Δx = dfx \ fx
2013
x -= Δx
2114
if isapprox(x, xo, atol=atol, rtol=rtol)
22-
return x
15+
return NewtonSolution(x, :Default)
2316
end
2417
xo = x
2518
end
26-
return oftype(x, NaN)
19+
return NewtonSolution(x, :MaxitersExceeded)
2720
end
2821

29-
"""
30-
ScalarBisection
31-
32-
Fast Bisection for scalar problems. Note that it doesn't returns exact solution, but returns
33-
the best left limit of the exact solution.
34-
"""
35-
struct ScalarBisection <: AbstractNonlinearSolveAlgorithm end
36-
37-
function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarBisection, args...; maxiters = 1000, kwargs...) where {uType}
22+
function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
3823
f = Base.Fix2(prob.f, prob.p)
3924
left, right = prob.u0
4025
fl, fr = f(left), f(right)

src/solve.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ end
99
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
12+
# bracketing algorithms only solve scalar problems
1213
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
1314
kwargs...
1415
) where {uType, iip}
15-
16+
1617
if !(prob.u0 isa Tuple)
1718
error("You need to pass a tuple of u0 in bracketing algorithms.")
1819
end
@@ -60,7 +61,7 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
6061
fu = f(u, p)
6162
end
6263

63-
64+
6465
sol = build_newton_solution(u, Val(iip))
6566
if immutable
6667
return NewtonImmutableSolver(1, f, alg, u, fu, p, nothing, false, maxiters, internalnorm, :Default, tol, sol)
@@ -81,7 +82,7 @@ function DiffEqBase.solve!(solver::AbstractNonlinearSolver)
8182
if solver.iter == solver.maxiters
8283
solver.retcode = :MaxitersExceeded
8384
end
84-
set_solution!(solver)
85+
solver = set_solution(solver)
8586
return solver.sol
8687
end
8788

@@ -151,19 +152,25 @@ function check_for_exact_solution!(solver::BracketingSolver)
151152
return false
152153
end
153154

154-
function set_solution!(solver::BracketingSolver)
155-
solver.sol.left = solver.left
156-
solver.sol.right = solver.right
157-
solver.sol.retcode = solver.retcode
155+
function set_solution(solver::BracketingSolver)
156+
sol = solver.sol
157+
@set! sol.left = solver.left
158+
@set! sol.right = solver.right
159+
@set! sol.retcode = solver.retcode
160+
@set! solver.sol = sol
161+
return solver
158162
end
159163

160164
function get_solution(solver::BracketingImmutableSolver)
161165
return (left = solver.left, right = solver.right, retcode = solver.retcode)
162166
end
163167

164-
function set_solution!(solver::NewtonSolver)
165-
solver.sol.u = solver.u
166-
solver.sol.retcode = solver.retcode
168+
function set_solution(solver::NewtonSolver)
169+
sol = solver.sol
170+
@set! sol.u = solver.u
171+
@set! sol.retcode = solver.retcode
172+
@set! solver.sol = sol
173+
return solver
167174
end
168175

169176
function get_solution(solver::NewtonImmutableSolver)

src/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function build_solution(u_prototype, ::Val{false})
8383
return BracketingSolution(zero(u_prototype), zero(u_prototype), :Default)
8484
end
8585

86-
mutable struct NewtonSolution{uType}
86+
struct NewtonSolution{uType}
8787
u::uType
8888
retcode::Symbol
8989
end

test/runtests.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,35 @@ using StaticArrays
33
using BenchmarkTools
44
using Test
55

6-
function benchmark_immutable()
7-
probN = NonlinearProblem((u,p) -> u .* u .- 2, @SVector[1.0, 1.0])
6+
function benchmark_immutable(f, u0)
7+
probN = NonlinearProblem{false}(f, u0)
88
solver = init(probN, NewtonRaphson(), immutable = true, tol = 1e-9)
9-
sol = @btime solve!($solver)
10-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
9+
sol = solve!(solver)
1110
end
1211

13-
function benchmark_mutable()
14-
probN = NonlinearProblem((u,p) -> u .* u .- 2, @SVector[1.0, 1.0])
12+
function benchmark_mutable(f, u0)
13+
probN = NonlinearProblem{false}(f, u0)
1514
solver = init(probN, NewtonRaphson(), immutable = false, tol = 1e-9)
16-
sol = @btime (reinit!($solver, $probN); solve!($solver))
17-
@test all(sol.u .* sol.u .- 2 .< 1e-9)
15+
sol = (reinit!(solver, probN); solve!(solver))
1816
end
1917

20-
function benchmark_scalar()
21-
probN = NonlinearProblem((u,p) -> u .* u .- 2, 1.0)
22-
sol = @btime (solve($probN, ScalarNewton()))
23-
@test sol * sol - 2 < 1e-9
18+
function benchmark_scalar(f, u0)
19+
probN = NonlinearProblem{false}(f, u0)
20+
sol = (solve(probN, NewtonRaphson()))
2421
end
2522

26-
benchmark_immutable()
27-
benchmark_mutable()
28-
benchmark_scalar()
23+
f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0]
24+
sf, su0 = (u,p) -> u * u - 2, 1.0
25+
sol = benchmark_immutable(f, u0)
26+
@test sol.retcode === :Default
27+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
28+
sol = benchmark_mutable(f, u0)
29+
@test sol.retcode === :Default
30+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
31+
sol = benchmark_scalar(sf, su0)
32+
@test sol.retcode === :Default
33+
@test sol.u * sol.u - 2 < 1e-9
34+
35+
@test (@ballocated benchmark_immutable($f, $u0)) == 0
36+
@test (@ballocated benchmark_mutable($f, $u0)) < 200
37+
@test (@ballocated benchmark_scalar($sf, $su0)) == 0

0 commit comments

Comments
 (0)