Skip to content

Commit 5292a61

Browse files
Merge pull request #382 from SciML/ap/fake_caching
Add caching for solvers without init
2 parents bf072d2 + e37b70a commit 5292a61

File tree

9 files changed

+93
-20
lines changed

9 files changed

+93
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.6.0"
4+
version = "3.7.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ end
2323
kwargs
2424
end
2525

26+
function Base.show(io::IO, cache::LeastSquaresOptimJLCache)
27+
print(io, "LeastSquaresOptimJLCache()")
28+
end
29+
2630
function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...)
2731
error("Reinitialization not supported for LeastSquaresOptimJL.")
2832
end

ext/NonlinearSolveNLSolversExt.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra
44
using FiniteDiff, ForwardDiff
55

66
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
7-
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
8-
termination_condition = nothing, kwargs...)
7+
abstol = nothing, reltol = nothing, maxiters = 1000,
8+
alias_u0::Bool = false, termination_condition = nothing, kwargs...)
99
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)
1010

1111
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
@@ -50,12 +50,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
5050
prob_nlsolver = NEqProblem(prob_obj; inplace = false)
5151
res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options)
5252

53-
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
54-
ReturnCode.MaxIters)
53+
retcode = ifelse(norm(res.info.best_residual, Inf) abstol,
54+
ReturnCode.Success, ReturnCode.MaxIters)
5555
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
5656

57-
return SciMLBase.build_solution(prob, alg, res.info.solution,
58-
res.info.best_residual; retcode, original = res, stats)
57+
return SciMLBase.build_solution(
58+
prob, alg, res.info.solution, res.info.best_residual;
59+
retcode, original = res, stats)
5960
end
6061

6162
f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
@@ -73,12 +74,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
7374

7475
res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options)
7576

76-
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
77-
ReturnCode.MaxIters)
77+
retcode = ifelse(
78+
norm(res.info.best_residual, Inf) abstol, ReturnCode.Success, ReturnCode.MaxIters)
7879
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
7980

80-
return SciMLBase.build_solution(prob, alg, res.info.solution,
81-
res.info.best_residual; retcode, original = res, stats)
81+
return SciMLBase.build_solution(prob, alg, res.info.solution, res.info.best_residual;
82+
retcode, original = res, stats)
8283
end
8384

8485
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ include("core/generic.jl")
6666
include("core/approximate_jacobian.jl")
6767
include("core/generalized_first_order.jl")
6868
include("core/spectral_methods.jl")
69+
include("core/noinit.jl")
6970

7071
include("algorithms/raphson.jl")
7172
include("algorithms/pseudo_transient.jl")

src/abstract_types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ get_u(cache::AbstractNonlinearSolveCache) = cache.u
214214
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
215215
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)
216216

217+
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache; kwargs...)
218+
return reinit_cache!(cache; kwargs...)
219+
end
217220
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0; kwargs...)
218221
return reinit_cache!(cache; u0, kwargs...)
219222
end

src/core/noinit.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Some algorithms don't support creating a cache and doing `solve!`, this unfortunately
2+
# makes it difficult to write generic code that supports caching. For the algorithms that
3+
# don't have a `__init` function defined, we create a "Fake Cache", which just calls
4+
# `__solve` from `solve!`
5+
@concrete mutable struct NonlinearSolveNoInitCache{iip, timeit} <:
6+
AbstractNonlinearSolveCache{iip, timeit}
7+
prob
8+
alg
9+
args
10+
kwargs::Any
11+
end
12+
13+
function SciMLBase.reinit!(
14+
cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs...)
15+
prob = remake(cache.prob; u0, p)
16+
cache.prob = prob
17+
cache.kwargs = merge(cache.kwargs, kwargs)
18+
return cache
19+
end
20+
21+
function Base.show(io::IO, cache::NonlinearSolveNoInitCache)
22+
print(io, "NonlinearSolveNoInitCache(alg = $(cache.alg))")
23+
end
24+
25+
function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
26+
alg::Union{AbstractNonlinearSolveAlgorithm,
27+
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm},
28+
args...;
29+
maxtime = nothing,
30+
kwargs...) where {uType, iip}
31+
return NonlinearSolveNoInitCache{iip, maxtime !== nothing}(
32+
prob, alg, args, merge((; maxtime), kwargs))
33+
end
34+
35+
function SciMLBase.solve!(cache::NonlinearSolveNoInitCache)
36+
return solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
37+
end

test/misc/noinit_caching_tests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@testitem "NoInit Caching" begin
2+
using LinearAlgebra
3+
import NLsolve, NLSolvers
4+
5+
solvers = [SimpleNewtonRaphson(), SimpleTrustRegion(), SimpleDFSane(), NLsolveJL(),
6+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking()))]
7+
8+
prob = NonlinearProblem((u, p) -> u .^ 2 .- p, [0.1, 0.3], 2.0)
9+
10+
for alg in solvers
11+
cache = init(prob, alg)
12+
sol = solve!(cache)
13+
@test SciMLBase.successful_retcode(sol)
14+
@test norm(sol.resid, Inf) 1e-6
15+
16+
reinit!(cache; p = 5.0)
17+
@test cache.prob.p == 5.0
18+
sol = solve!(cache)
19+
@test SciMLBase.successful_retcode(sol)
20+
@test norm(sol.resid, Inf) 1e-6
21+
@test norm(sol.u .^ 2 .- 5.0, Inf) 1e-6
22+
end
23+
end

test/misc/qa_tests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
@testitem "Aqua" begin
2-
using NonlinearSolve, Aqua
2+
using NonlinearSolve, SimpleNonlinearSolve, Aqua
33

44
Aqua.find_persistent_tasks_deps(NonlinearSolve)
55
Aqua.test_ambiguities(NonlinearSolve; recursive = false)
66
Aqua.test_deps_compat(NonlinearSolve)
7-
Aqua.test_piracies(
8-
NonlinearSolve, treat_as_own = [NonlinearProblem, NonlinearLeastSquaresProblem])
7+
Aqua.test_piracies(NonlinearSolve,
8+
treat_as_own = [NonlinearProblem, NonlinearLeastSquaresProblem,
9+
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm])
910
Aqua.test_project_extras(NonlinearSolve)
1011
# Timer Outputs needs to be enabled via Preferences
1112
Aqua.test_stale_deps(NonlinearSolve; ignore = [:TimerOutputs])

test/wrappers/rootfind_tests.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ end
1616
prob_iip = SteadyStateProblem(f_iip, u0)
1717

1818
for alg in [
19-
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
19+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
20+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
2021
sol = solve(prob_iip, alg)
2122
@test SciMLBase.successful_retcode(sol.retcode)
2223
@test maximum(abs, sol.resid) < 1e-6
@@ -28,7 +29,8 @@ end
2829
prob_oop = SteadyStateProblem(f_oop, u0)
2930

3031
for alg in [
31-
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
32+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
33+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
3234
sol = solve(prob_oop, alg)
3335
@test SciMLBase.successful_retcode(sol.retcode)
3436
@test maximum(abs, sol.resid) < 1e-6
@@ -45,7 +47,8 @@ end
4547
prob_iip = NonlinearProblem{true}(f_iip, u0)
4648

4749
for alg in [
48-
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
50+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
51+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
4952
local sol
5053
sol = solve(prob_iip, alg)
5154
@test SciMLBase.successful_retcode(sol.retcode)
@@ -57,7 +60,8 @@ end
5760
u0 = zeros(2)
5861
prob_oop = NonlinearProblem{false}(f_oop, u0)
5962
for alg in [
60-
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
63+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
64+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
6165
local sol
6266
sol = solve(prob_oop, alg)
6367
@test SciMLBase.successful_retcode(sol.retcode)
@@ -70,8 +74,7 @@ end
7074
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15],
7175
alg in [
7276
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
73-
NLsolveJL(),
74-
CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
77+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
7578
SIAMFANLEquationsJL(; method = :pseudotransient),
7679
SIAMFANLEquationsJL(; method = :secant)]
7780

0 commit comments

Comments
 (0)