Skip to content

Commit 004634d

Browse files
committed
test(NonlinearSolveSpectralMethods): add tests and ci scripts
1 parent b8b210d commit 004634d

File tree

9 files changed

+285
-72
lines changed

9 files changed

+285
-72
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
name: CI (NonlinearSolveSpectralMethods)
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- master
7+
paths:
8+
- "lib/NonlinearSolveSpectralMethods/**"
9+
- ".github/workflows/CI_NonlinearSolveSpectralMethods.yml"
10+
- "lib/NonlinearSolveBase/**"
11+
- "lib/SciMLJacobianOperators/**"
12+
push:
13+
branches:
14+
- master
15+
16+
concurrency:
17+
# Skip intermediate builds: always.
18+
# Cancel intermediate builds: only if it is a pull request build.
19+
group: ${{ github.workflow }}-${{ github.ref }}
20+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
21+
22+
jobs:
23+
test:
24+
runs-on: ${{ matrix.os }}
25+
strategy:
26+
fail-fast: false
27+
matrix:
28+
version:
29+
- "lts"
30+
- "1"
31+
os:
32+
- ubuntu-latest
33+
- macos-latest
34+
- windows-latest
35+
steps:
36+
- uses: actions/checkout@v4
37+
- uses: julia-actions/setup-julia@v2
38+
with:
39+
version: ${{ matrix.version }}
40+
- uses: actions/cache@v4
41+
env:
42+
cache-name: cache-artifacts
43+
with:
44+
path: ~/.julia/artifacts
45+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
46+
restore-keys: |
47+
${{ runner.os }}-test-${{ env.cache-name }}-
48+
${{ runner.os }}-test-
49+
${{ runner.os }}-
50+
- name: "Install Dependencies and Run Tests"
51+
run: |
52+
import Pkg
53+
Pkg.Registry.update()
54+
# Install packages present in subdirectories
55+
dev_pks = Pkg.PackageSpec[]
56+
for path in ("lib/SciMLJacobianOperators", "lib/NonlinearSolveBase")
57+
push!(dev_pks, Pkg.PackageSpec(; path))
58+
end
59+
Pkg.develop(dev_pks)
60+
Pkg.instantiate()
61+
Pkg.test(; coverage="user")
62+
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/NonlinearSolveSpectralMethods {0}
63+
- uses: julia-actions/julia-processcoverage@v1
64+
with:
65+
directories: lib/NonlinearSolveSpectralMethods/src,lib/NonlinearSolveBase/src,lib/NonlinearSolveBase/ext,lib/SciMLJacobianOperators/src
66+
- uses: codecov/codecov-action@v4
67+
with:
68+
file: lcov.info
69+
token: ${{ secrets.CODECOV_TOKEN }}
70+
verbose: true
71+
fail_ci_if_error: true

common/common_core_testing.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using NonlinearSolveBase, SciMLBase
2+
3+
const TERMINATION_CONDITIONS = [
4+
NormTerminationMode(Base.Fix1(maximum, abs)),
5+
RelTerminationMode(),
6+
RelNormTerminationMode(Base.Fix1(maximum, abs)),
7+
RelNormSafeTerminationMode(Base.Fix1(maximum, abs)),
8+
RelNormSafeBestTerminationMode(Base.Fix1(maximum, abs)),
9+
AbsTerminationMode(),
10+
AbsNormTerminationMode(Base.Fix1(maximum, abs)),
11+
AbsNormSafeTerminationMode(Base.Fix1(maximum, abs)),
12+
AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs))
13+
]
14+
15+
quadratic_f(u, p) = u .* u .- p
16+
quadratic_f!(du, u, p) = (du .= u .* u .- p)
17+
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
18+
19+
function newton_fails(u, p)
20+
return 0.010000000000000002 .+
21+
10.000000000000002 ./ (1 .+
22+
(0.21640425613334457 .+
23+
216.40425613334457 ./ (1 .+
24+
(0.21640425613334457 .+
25+
216.40425613334457 ./ (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^
26+
2.0) .- 0.0011552453009332421u .- p
27+
end
28+
29+
function solve_oop(f, u0, p = 2.0; solver, kwargs...)
30+
prob = NonlinearProblem{false}(f, u0, p)
31+
return solve(prob, solver; abstol = 1e-9, kwargs...)
32+
end
33+
34+
function solve_iip(f, u0, p = 2.0; solver, kwargs...)
35+
prob = NonlinearProblem{true}(f, u0, p)
36+
return solve(prob, solver; abstol = 1e-9, kwargs...)
37+
end
38+
39+
function nlprob_iterator_interface(f, p_range, isinplace, solver)
40+
probN = NonlinearProblem{isinplace}(f, isinplace ? [0.5] : 0.5, p_range[begin])
41+
cache = init(probN, solver; maxiters = 100, abstol = 1e-10)
42+
sols = zeros(length(p_range))
43+
for (i, p) in enumerate(p_range)
44+
reinit!(cache, isinplace ? [cache.u[1]] : cache.u; p = p)
45+
sol = solve!(cache)
46+
sols[i] = isinplace ? sol.u[1] : sol.u
47+
end
48+
return sols
49+
end
50+
51+
export TERMINATION_CONDITIONS
52+
export quadratic_f, quadratic_f!, quadratic_f2, newton_fails
53+
export solve_oop, solve_iip
54+
export nlprob_iterator_interface

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ArrayInterface: ArrayInterface
44
using FastClosures: @closure
55
using LinearAlgebra: LinearAlgebra, Diagonal, Symmetric, norm, dot, cond, diagind, pinv
66
using MaybeInplace: @bb
7-
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
7+
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition, recursivecopy!
88
using SciMLOperators: AbstractSciMLOperator
99
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction
1010
using StaticArraysCore: StaticArray, SArray, SMatrix
@@ -242,6 +242,17 @@ function make_identity!!(A::AbstractMatrix{T}, α) where {T}
242242
return A
243243
end
244244

245+
function reinit_common!(cache, u0, p, alias_u0::Bool)
246+
if SciMLBase.isinplace(cache)
247+
recursivecopy!(cache.u, u0)
248+
cache.prob.f(cache.fu, cache.u, p)
249+
else
250+
cache.u = maybe_unaliased(u0, alias_u0)
251+
NonlinearSolveBase.set_fu!(cache, cache.prob.f(u0, p))
252+
end
253+
cache.p = p
254+
end
255+
245256
function clean_sprint_struct(x)
246257
x isa Symbol && return "$(Meta.quot(x))"
247258
x isa Number && return string(x)

lib/NonlinearSolveSpectralMethods/Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
88
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1010
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
11-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1312
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1413
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -17,14 +16,14 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1716

1817
[compat]
1918
Aqua = "0.8"
19+
BenchmarkTools = "1.5.0"
2020
CommonSolve = "0.2.4"
2121
ConcreteStructs = "0.2.3"
2222
DiffEqBase = "6.155.3"
2323
ExplicitImports = "1.5"
2424
Hwloc = "3"
2525
InteractiveUtils = "<0.0.1, 1"
2626
LineSearch = "0.1.4"
27-
LinearAlgebra = "1.11.0"
2827
MaybeInplace = "0.1.4"
2928
NonlinearProblemLibrary = "0.1.2"
3029
NonlinearSolveBase = "1.1"
@@ -34,19 +33,22 @@ ReTestItems = "1.24"
3433
Reexport = "1"
3534
SciMLBase = "2.54"
3635
StableRNGs = "1"
36+
StaticArrays = "1.9.8"
3737
Test = "1.10"
3838
julia = "1.10"
3939

4040
[extras]
4141
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
42+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4243
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
4344
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
4445
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4546
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
4647
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4748
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
4849
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
50+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4951
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5052

5153
[targets]
52-
test = ["Aqua", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "Test"]
54+
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
module NonlinearSolveSpectralMethods
22

3+
using ConcreteStructs: @concrete
34
using Reexport: @reexport
45
using PrecompileTools: @compile_workload, @setup_workload
56

67
using CommonSolve: CommonSolve
7-
using ConcreteStructs: @concrete
88
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
9-
using LinearAlgebra: dot
109
using LineSearch: RobustNonMonotoneLineSearch
1110
using MaybeInplace: @bb
1211
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
@@ -19,9 +18,7 @@ include("dfsane.jl")
1918
include("solve.jl")
2019

2120
@setup_workload begin
22-
include(joinpath(
23-
@__DIR__, "..", "..", "..", "common", "nonlinear_problem_workloads.jl"
24-
))
21+
include("../../../common/nonlinear_problem_workloads.jl")
2522

2623
algs = [DFSane()]
2724

lib/NonlinearSolveSpectralMethods/src/solve.jl

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,43 +70,41 @@ end
7070
kwargs
7171
end
7272

73-
# XXX: Implement
74-
# function __reinit_internal!(
75-
# cache::GeneralizedDFSaneCache{iip}, args...; p = cache.p, u0 = cache.u,
76-
# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}
77-
# if iip
78-
# recursivecopy!(cache.u, u0)
79-
# cache.prob.f(cache.fu, cache.u, p)
80-
# else
81-
# cache.u = __maybe_unaliased(u0, alias_u0)
82-
# set_fu!(cache, cache.prob.f(cache.u, p))
83-
# end
84-
# cache.p = p
85-
86-
# if cache.alg.σ_1 === nothing
87-
# σ_n = dot(cache.u, cache.u) / dot(cache.u, cache.fu)
88-
# # Spectral parameter bounds check
89-
# if !(cache.alg.σ_min ≤ abs(σ_n) ≤ cache.alg.σ_max)
90-
# test_norm = dot(cache.fu, cache.fu)
91-
# σ_n = clamp(inv(test_norm), T(1), T(1e5))
92-
# end
93-
# else
94-
# σ_n = T(cache.alg.σ_1)
95-
# end
96-
# cache.σ_n = σ_n
97-
98-
# reset_timer!(cache.timer)
99-
# cache.total_time = 0.0
100-
101-
# reset!(cache.trace)
102-
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
103-
# __reinit_internal!(cache.stats)
104-
# cache.nsteps = 0
105-
# cache.maxiters = maxiters
106-
# cache.maxtime = maxtime
107-
# cache.force_stop = false
108-
# cache.retcode = ReturnCode.Default
109-
# end
73+
function InternalAPI.reinit_self!(
74+
cache::GeneralizedDFSaneCache, args...; p = cache.p, u0 = cache.u,
75+
alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...
76+
)
77+
Utils.reinit_common!(cache, u0, p, alias_u0)
78+
79+
if cache.alg.σ_1 === nothing
80+
σ_n = Utils.safe_dot(cache.u, cache.u) / Utils.safe_dot(cache.u, cache.fu)
81+
# Spectral parameter bounds check
82+
if !(cache.alg.σ_min abs(σ_n) cache.alg.σ_max)
83+
test_norm = NonlinearSolveBase.L2_NORM(cache.fu)
84+
σ_n = clamp(inv(test_norm), T(1), T(1e5))
85+
end
86+
else
87+
σ_n = T(cache.alg.σ_1)
88+
end
89+
cache.σ_n = σ_n
90+
91+
NonlinearSolveBase.reset_timer!(cache.timer)
92+
cache.total_time = 0.0
93+
94+
NonlinearSolveBase.reset!(cache.trace)
95+
SciMLBase.reinit!(
96+
cache.termination_cache, NonlinearSolveBase.get_fu(cache),
97+
NonlinearSolveBase.get_u(cache); kwargs...
98+
)
99+
100+
InternalAPI.reinit!(cache.stats)
101+
cache.nsteps = 0
102+
cache.maxiters = maxiters
103+
cache.maxtime = maxtime
104+
cache.force_stop = false
105+
cache.retcode = ReturnCode.Default
106+
return
107+
end
110108

111109
NonlinearSolveBase.@internal_caches GeneralizedDFSaneCache :linesearch_cache
112110

@@ -137,7 +135,7 @@ function SciMLBase.__init(
137135
)
138136

139137
if alg.σ_1 === nothing
140-
σ_n = dot(u, u) / dot(u, fu)
138+
σ_n = Utils.safe_dot(u, u) / Utils.safe_dot(u, fu)
141139
# Spectral parameter bounds check
142140
if !(alg.σ_min abs(σ_n) alg.σ_max)
143141
test_norm = NonlinearSolveBase.L2_NORM(fu)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
@testsetup module CoreRootfindTesting
2+
3+
include("../../../common/common_core_testing.jl")
4+
5+
end
6+
7+
@testitem "DFSane" setup=[CoreRootfindTesting] tags=[:core] begin
8+
using BenchmarkTools: @ballocated
9+
using StaticArrays: @SVector
10+
11+
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
12+
13+
@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
14+
sol = solve_oop(quadratic_f, u0; solver = DFSane())
15+
@test SciMLBase.successful_retcode(sol)
16+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
17+
18+
cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), DFSane(), abstol = 1e-9)
19+
@test (@ballocated solve!($cache)) < 200
20+
end
21+
22+
@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
23+
sol = solve_iip(quadratic_f!, u0; solver = DFSane())
24+
@test SciMLBase.successful_retcode(sol)
25+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
26+
27+
cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), DFSane(), abstol = 1e-9)
28+
@test (@ballocated solve!($cache)) 64
29+
end
30+
end
31+
32+
@testitem "DFSane Iterator Interface" setup=[CoreRootfindTesting] tags=[:core] begin
33+
p = range(0.01, 2, length = 200)
34+
@test nlprob_iterator_interface(quadratic_f, p, false, DFSane()) sqrt.(p)
35+
@test nlprob_iterator_interface(quadratic_f!, p, true, DFSane()) sqrt.(p)
36+
end
37+
38+
@testitem "DFSane NewtonRaphson Fails" setup=[CoreRootfindTesting] tags=[:core] begin
39+
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
40+
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
41+
sol = solve_oop(newton_fails, u0, p; solver = DFSane())
42+
@test SciMLBase.successful_retcode(sol)
43+
@test all(abs.(newton_fails(sol.u, p)) .< 1e-9)
44+
end
45+
46+
@testitem "DFSane: Kwargs" setup=[CoreRootfindTesting] tags=[:core] begin
47+
σ_min = [1e-10, 1e-5, 1e-4]
48+
σ_max = [1e10, 1e5, 1e4]
49+
σ_1 = [1.0, 0.5, 2.0]
50+
M = [10, 1, 100]
51+
γ = [1e-4, 1e-3, 1e-5]
52+
τ_min = [0.1, 0.2, 0.3]
53+
τ_max = [0.5, 0.8, 0.9]
54+
nexp = [2, 1, 2]
55+
η_strategy = [
56+
(f_1, k, x, F) -> f_1 / k^2, (f_1, k, x, F) -> f_1 / k^3,
57+
(f_1, k, x, F) -> f_1 / k^4
58+
]
59+
60+
list_of_options = zip(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, η_strategy)
61+
for options in list_of_options
62+
local probN, sol, alg
63+
alg = DFSane(;
64+
sigma_min = options[1], sigma_max = options[2], sigma_1 = options[3],
65+
M = options[4], gamma = options[5], tau_min = options[6],
66+
tau_max = options[7], n_exp = options[8], eta_strategy = options[9]
67+
)
68+
69+
probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
70+
sol = solve(probN, alg, abstol = 1e-11)
71+
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-6)
72+
end
73+
end
74+
75+
@testitem "DFSane Termination Conditions" setup=[CoreRootfindTesting] tags=[:core] begin
76+
@testset "TC: $(nameof(typeof(termination_condition)))" for termination_condition in TERMINATION_CONDITIONS
77+
@testset "u0: $(typeof(u0))" for u0 in ([1.0, 1.0], 1.0)
78+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
79+
sol = solve(probN, DFSane(); termination_condition)
80+
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
81+
end
82+
end
83+
end

0 commit comments

Comments
 (0)