Skip to content

Commit d5f18b7

Browse files
committed
test: comprehensive testing of root finding
1 parent 61ea504 commit d5f18b7

File tree

9 files changed

+195
-17
lines changed

9 files changed

+195
-17
lines changed

docs/src/release_notes.md

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,32 @@
11
# Release Notes
22

3-
## Breaking Changes in `NonlinearSolve.jl` v3
3+
## Oct '24
4+
5+
### Breaking Changes in `NonlinearSolve.jl` v4
6+
7+
### Breaking Changes in `SimpleNonlinearSolve.jl` v2
8+
9+
- `Auto*` structs are no longer exported. Load `ADTypes` to access them.
10+
- Use of termination conditions from `DiffEqBase` has been removed. Use the termination
11+
conditions from `NonlinearSolveBase` instead.
12+
- We no longer export the entire `SciMLBase`. Instead selected functionality relevant to
13+
`SimpleNonlinearSolve` has been exported.
14+
- If no autodiff is provided, we now choose from a list of autodiffs based on the packages
15+
loaded. For example, if `Enzyme` is loaded, we will default to that. In general, we
16+
don't guarantee the exact autodiff selected if `autodiff` is not provided (i.e.
17+
`nothing`).
18+
19+
## Dec '23
20+
21+
### Breaking Changes in `NonlinearSolve.jl` v3
422

523
- `GeneralBroyden` and `GeneralKlement` have been renamed to `Broyden` and `Klement`
624
respectively.
725
- Compat for `SimpleNonlinearSolve` has been bumped to `v1`.
826
- The old style of specifying autodiff with `chunksize`, `standardtag`, etc. has been
927
deprecated in favor of directly specifying the autodiff type, like `AutoForwardDiff`.
1028

11-
## Breaking Changes in `SimpleNonlinearSolve.jl` v1
29+
### Breaking Changes in `SimpleNonlinearSolve.jl` v1
1230

1331
- Batched solvers have been removed in favor of `BatchedArrays.jl`. Stay tuned for detailed
1432
tutorials on how to use `BatchedArrays.jl` with `NonlinearSolve` & `SimpleNonlinearSolve`

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function update_u!!(cache::NonlinearTerminationModeCache, u)
2626
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
2727
copyto!(cache.u, u)
2828
else
29-
cache.u .= u
29+
cache.u = u
3030
end
3131
end
3232

@@ -60,6 +60,8 @@ function SciMLBase.init(
6060
else
6161
u_diff_cache = u_unaliased
6262
end
63+
best_value = initial_objective
64+
max_stalled_steps = mode.max_stalled_steps
6365
else
6466
initial_objective = nothing
6567
objectives_trace = nothing

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
module SimpleNonlinearSolve
22

33
using Accessors: @reset
4-
using CommonSolve: CommonSolve, solve
4+
using CommonSolve: CommonSolve, solve, init, solve!
55
using ConcreteStructs: @concrete
66
using FastClosures: @closure
77
using LineSearch: LiFukushimaLineSearch
88
using LinearAlgebra: LinearAlgebra, dot
99
using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex
1010
using PrecompileTools: @compile_workload, @setup_workload
11-
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, NonlinearLeastSquaresProblem,
12-
IntervalNonlinearProblem, ReturnCode
11+
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, NonlinearFunction, NonlinearProblem,
12+
NonlinearLeastSquaresProblem, IntervalNonlinearProblem, ReturnCode, remake
1313
using StaticArraysCore: StaticArray, SArray, SVector, MArray
1414

1515
# AD Dependencies
16-
using ADTypes: ADTypes
16+
using ADTypes: ADTypes, AutoForwardDiff
1717
using DifferentiationInterface: DifferentiationInterface
1818
using FiniteDiff: FiniteDiff
1919
using ForwardDiff: ForwardDiff
@@ -121,7 +121,7 @@ export IntervalNonlinearProblem
121121

122122
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
123123

124-
export NonlinearProblem, NonlinearLeastSquaresProblem
124+
export NonlinearFunction, NonlinearProblem, NonlinearLeastSquaresProblem
125125

126126
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
127127
export SimpleDFSane

lib/SimpleNonlinearSolve/src/halley.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ function SciMLBase.__solve(
3535
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
3636
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
3737

38-
autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
38+
# The way we write the 2nd order derivatives, we know Enzyme won't work there
39+
autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
3940

4041
@bb xo = copy(x)
4142

lib/SimpleNonlinearSolve/src/lbroyden.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function SciMLBase.__solve(
3838
args...; termination_condition = nothing, kwargs...)
3939
if prob.u0 isa SArray
4040
if termination_condition === nothing ||
41-
termination_condition isa AbsNormTerminationMode
41+
termination_condition isa NonlinearSolveBase.AbsNormTerminationMode
4242
return internal_static_solve(
4343
prob, alg, args...; termination_condition, kwargs...)
4444
end

lib/SimpleNonlinearSolve/test/core/23_test_problems_tests.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testsnippet RobustnessTestSnippet begin
2-
using NonlinearProblemLibrary, NonlinearSolveBase, LinearAlgebra
2+
using NonlinearProblemLibrary, NonlinearSolveBase, LinearAlgebra, ADTypes
33

44
problems = NonlinearProblemLibrary.problems
55
dicts = NonlinearProblemLibrary.dicts
@@ -40,7 +40,7 @@
4040
end
4141

4242
@testitem "23 Test Problems: SimpleNewtonRaphson" setup=[RobustnessTestSnippet] tags=[:core] begin
43-
alg_ops = (SimpleNewtonRaphson(),)
43+
alg_ops = (SimpleNewtonRaphson(; autodiff = AutoForwardDiff()),)
4444

4545
broken_tests = Dict(alg => Int[] for alg in alg_ops)
4646
broken_tests[alg_ops[1]] = []
@@ -49,7 +49,7 @@ end
4949
end
5050

5151
@testitem "23 Test Problems: SimpleHalley" setup=[RobustnessTestSnippet] tags=[:core] begin
52-
alg_ops = (SimpleHalley(),)
52+
alg_ops = (SimpleHalley(; autodiff = AutoForwardDiff()),)
5353

5454
broken_tests = Dict(alg => Int[] for alg in alg_ops)
5555
if Sys.isapple()
@@ -62,7 +62,10 @@ end
6262
end
6363

6464
@testitem "23 Test Problems: SimpleTrustRegion" setup=[RobustnessTestSnippet] tags=[:core] begin
65-
alg_ops = (SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)))
65+
alg_ops = (
66+
SimpleTrustRegion(; autodiff = AutoForwardDiff()),
67+
SimpleTrustRegion(; nlsolve_update_rule = Val(true), autodiff = AutoForwardDiff())
68+
)
6669

6770
broken_tests = Dict(alg => Int[] for alg in alg_ops)
6871
broken_tests[alg_ops[1]] = [3, 15, 16, 21]

lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testitem "BigFloat Support" tags=[:core] begin
2-
using SimpleNonlinearSolve, LinearAlgebra
2+
using SimpleNonlinearSolve, LinearAlgebra, ADTypes, SciMLBase
33

44
fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
55
fn_oop = NonlinearFunction{false}((u, p) -> u .* u .- p)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,155 @@
1+
@testsnippet RootfindTestSnippet begin
2+
using StaticArrays, Random, LinearAlgebra, ForwardDiff, NonlinearSolveBase, SciMLBase
3+
using PolyesterForwardDiff, Enzyme, ReverseDiff
14

5+
quadratic_f(u, p) = u .* u .- p
6+
quadratic_f!(du, u, p) = (du .= u .* u .- p)
7+
8+
function newton_fails(u, p)
9+
return 0.010000000000000002 .+
10+
10.000000000000002 ./ (1 .+
11+
(0.21640425613334457 .+
12+
216.40425613334457 ./ (1 .+
13+
(0.21640425613334457 .+
14+
216.40425613334457 ./ (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^
15+
2.0) .- 0.0011552453009332421u .- p
16+
end
17+
18+
const TERMINATION_CONDITIONS = [
19+
NormTerminationMode(Base.Fix1(maximum, abs)),
20+
RelTerminationMode(),
21+
RelNormTerminationMode(Base.Fix1(maximum, abs)),
22+
RelNormSafeTerminationMode(Base.Fix1(maximum, abs)),
23+
RelNormSafeBestTerminationMode(Base.Fix1(maximum, abs)),
24+
AbsTerminationMode(),
25+
AbsNormTerminationMode(Base.Fix1(maximum, abs)),
26+
AbsNormSafeTerminationMode(Base.Fix1(maximum, abs)),
27+
AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs))
28+
]
29+
30+
function run_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F}
31+
return solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9)
32+
end
33+
function run_nlsolve_iip(f!::F, u0, p = 2.0; solver) where {F}
34+
return solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9)
35+
end
36+
end
37+
38+
@testitem "First Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
39+
@testset for alg in (
40+
SimpleNewtonRaphson,
41+
SimpleTrustRegion,
42+
(; kwargs...) -> SimpleTrustRegion(; kwargs..., nlsolve_update_rule = Val(true))
43+
)
44+
@testset for autodiff in (
45+
AutoForwardDiff(),
46+
AutoFiniteDiff(),
47+
AutoReverseDiff(),
48+
AutoEnzyme(),
49+
nothing
50+
)
51+
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
52+
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
53+
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff))
54+
@test SciMLBase.successful_retcode(sol)
55+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
56+
end
57+
58+
@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
59+
sol = run_nlsolve_iip(quadratic_f!, u0; solver = alg(; autodiff))
60+
@test SciMLBase.successful_retcode(sol)
61+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
62+
end
63+
64+
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
65+
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
66+
67+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
68+
@test all(solve(
69+
probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈
70+
sqrt(2.0))
71+
end
72+
end
73+
end
74+
end
75+
76+
@testitem "Second Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
77+
@testset for alg in (
78+
SimpleHalley,
79+
)
80+
@testset for autodiff in (
81+
AutoForwardDiff(),
82+
AutoFiniteDiff(),
83+
AutoReverseDiff(),
84+
nothing
85+
)
86+
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
87+
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
88+
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff))
89+
@test SciMLBase.successful_retcode(sol)
90+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
91+
end
92+
end
93+
94+
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
95+
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
96+
97+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
98+
@test all(solve(
99+
probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈
100+
sqrt(2.0))
101+
end
102+
end
103+
end
104+
105+
@testitem "Derivative Free Metods" setup=[RootfindTestSnippet] tags=[:core] begin
106+
@testset "$(nameof(typeof(alg)))" for alg in (
107+
SimpleBroyden(),
108+
SimpleKlement(),
109+
SimpleDFSane(),
110+
SimpleLimitedMemoryBroyden(),
111+
SimpleBroyden(; linesearch = Val(true)),
112+
SimpleLimitedMemoryBroyden(; linesearch = Val(true))
113+
)
114+
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
115+
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg)
116+
@test SciMLBase.successful_retcode(sol)
117+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
118+
end
119+
120+
@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
121+
sol = run_nlsolve_iip(quadratic_f!, u0; solver = alg)
122+
@test SciMLBase.successful_retcode(sol)
123+
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
124+
end
125+
126+
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
127+
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
128+
129+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
130+
@test all(solve(probN, alg; termination_condition).u .≈ sqrt(2.0))
131+
end
132+
end
133+
end
134+
135+
@testitem "Newton Fails" setup=[RootfindTestSnippet] tags=[:core] begin
136+
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
137+
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
138+
139+
@testset "$(nameof(typeof(alg)))" for alg in (
140+
SimpleDFSane(),
141+
SimpleTrustRegion(),
142+
SimpleHalley(),
143+
SimpleTrustRegion(; nlsolve_update_rule = Val(true))
144+
)
145+
sol = run_nlsolve_oop(newton_fails, u0, p; solver = alg)
146+
@test SciMLBase.successful_retcode(sol)
147+
@test maximum(abs, newton_fails(sol.u, p)) < 1e-9
148+
end
149+
end
150+
151+
@testitem "Kwargs Propagation" setup=[RootfindTestSnippet] tags=[:core] begin
152+
prob = NonlinearProblem(quadratic_f, ones(4), 2.0; maxiters = 2)
153+
sol = solve(prob, SimpleNewtonRaphson())
154+
@test sol.retcode === ReturnCode.MaxIters
155+
end

lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testitem "Solving on CUDA" tags=[:cuda] begin
2-
using StaticArrays, CUDA, SimpleNonlinearSolve
2+
using StaticArrays, CUDA, SimpleNonlinearSolve, ADTypes
33

44
if CUDA.functional()
55
CUDA.allowscalar(false)
@@ -47,7 +47,7 @@
4747
end
4848

4949
@testitem "CUDA Kernel Launch Test" tags=[:cuda] begin
50-
using StaticArrays, CUDA, SimpleNonlinearSolve
50+
using StaticArrays, CUDA, SimpleNonlinearSolve, ADTypes
5151
using NonlinearSolveBase: ImmutableNonlinearProblem
5252

5353
if CUDA.functional()

0 commit comments

Comments
 (0)