Skip to content

Commit f662e3e

Browse files
committed
refactor: move DFSane into NonlinearSolveSpectralMethods
1 parent a30f80b commit f662e3e

File tree

15 files changed

+568
-255
lines changed

15 files changed

+568
-255
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1919
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2020
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
2121
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
22+
NonlinearSolveSpectralMethods = "26075421-4e9a-44e1-8bd1-420ed7ad02b2"
2223
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2324
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
24-
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2525
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2626
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2727
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -99,7 +99,6 @@ PETSc = "0.2"
9999
Pkg = "1.10"
100100
PrecompileTools = "1.2"
101101
Preferences = "1.4"
102-
Printf = "1.10"
103102
Random = "1.10"
104103
ReTestItems = "1.24"
105104
RecursiveArrayTools = "3.27"

common/nlls_problem_workloads.jl

Whitespace-only changes.

common/nonlinear_problem_workloads.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using SciMLBase: NonlinearProblem, NonlinearFunction
2+
3+
nonlinear_functions = (
4+
(NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
5+
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
6+
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])
7+
)
8+
9+
nonlinear_problems = NonlinearProblem[]
10+
for (fn, u0) in nonlinear_functions
11+
push!(nonlinear_problems, NonlinearProblem(fn, u0, 2.0))
12+
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear
2424
using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
2525
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
2626
using StaticArraysCore: StaticArray, SMatrix, SArray, MArray
27+
using SymbolicIndexingInterface: SymbolicIndexingInterface
2728

2829
const DI = DifferentiationInterface
30+
const SII = SymbolicIndexingInterface
2931

3032
include("public.jl")
3133
include("utils.jl")
@@ -49,6 +51,8 @@ include("descent/damped_newton.jl")
4951
include("descent/dogleg.jl")
5052
include("descent/geodesic_acceleration.jl")
5153

54+
include("solve.jl")
55+
5256
# Unexported Public API
5357
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
5458
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))

lib/NonlinearSolveBase/src/abstract_types.jl

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,111 @@ concrete_jac(v::Bool) = v
218218
concrete_jac(::Val{false}) = false
219219
concrete_jac(::Val{true}) = true
220220

221+
"""
222+
AbstractNonlinearSolveCache
223+
224+
Abstract Type for all NonlinearSolveBase Caches.
225+
226+
### Interface Functions
227+
228+
- `get_fu(cache)`: get the residual.
229+
- `get_u(cache)`: get the current state.
230+
- `set_fu!(cache, fu)`: set the residual.
231+
- `has_time_limit(cache)`: whether or not the solver has a maximum time limit.
232+
- `not_terminated(cache)`: whether or not the solver has terminated.
233+
234+
- `SciMLBase.set_u!(cache, u)`: set the current state.
235+
- `SciMLBase.reinit!(cache, u0; kwargs...)`: reinitialize the cache with the initial state
236+
`u0` and any additional keyword arguments.
237+
- `SciMLBase.step!(cache; kwargs...)`: See [`SciMLBase.step!`](@ref) for more details.
238+
- `SciMLBase.isinplace(cache)`: whether or not the solver is inplace.
239+
240+
Additionally implements `SymbolicIndexingInterface` interface Functions.
241+
242+
#### Expected Fields in Sub-Types
243+
244+
For the default interface implementations we expect the following fields to be present in
245+
the cache:
246+
247+
- `fu`: the residual.
248+
- `u`: the current state.
249+
- `maxiters`: the maximum number of iterations.
250+
- `nsteps`: the number of steps taken.
251+
- `force_stop`: whether or not the solver has been forced to stop.
252+
- `retcode`: the return code.
253+
- `stats`: `NLStats` object.
254+
- `alg`: the algorithm.
255+
- `maxtime`: the maximum time limit for the solver. (Optional)
256+
- `timer`: the timer for the solver. (Optional)
257+
- `total_time`: the total time taken by the solver. (Optional)
258+
"""
221259
abstract type AbstractNonlinearSolveCache <: AbstractNonlinearSolveBaseAPI end
222260

223-
function get_u end
224-
function get_fu end
261+
get_u(cache::AbstractNonlinearSolveCache) = cache.u
262+
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
263+
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
264+
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)
265+
266+
function has_time_limit(cache::AbstractNonlinearSolveCache)
267+
maxtime = Utils.safe_getproperty(cache, Val(:maxtime))
268+
return maxtime !== missing && maxtime !== nothing
269+
end
270+
271+
function not_terminated(cache::AbstractNonlinearSolveCache)
272+
return !cache.force_stop && cache.nsteps < cache.maxiters
273+
end
274+
275+
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache; kwargs...)
276+
return InternalAPI.reinit!(cache; kwargs...)
277+
end
278+
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0; kwargs...)
279+
return InternalAPI.reinit!(cache; u0, kwargs...)
280+
end
281+
282+
SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob)
283+
284+
## SII Interface
285+
SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob
286+
SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob)
287+
SII.state_values(cache::AbstractNonlinearSolveCache) = SII.state_values(cache.prob)
288+
289+
function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol)
290+
if sym === :ps
291+
!hasfield(typeof(cache), :ps) && return SII.ParameterIndexingProxy(cache)
292+
return getfield(cache, :ps)
293+
end
294+
return getfield(cache, sym)
295+
end
296+
297+
Base.getindex(cache::AbstractNonlinearSolveCache, sym) = SII.getu(cache, sym)(cache)
298+
function Base.setindex!(cache::AbstractNonlinearSolveCache, val, sym)
299+
return SII.setu(cache, sym)(cache, val)
300+
end
301+
302+
# XXX: Implement this
303+
# function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
304+
# __show_cache(io, cache, 0)
305+
# end
306+
307+
# function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
308+
# println(io, "$(nameof(typeof(cache)))(")
309+
# __show_algorithm(io, cache.alg,
310+
# (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)
311+
312+
# ustr = sprint(show, get_u(cache); context = (:compact => true, :limit => true))
313+
# println(io, ",\n" * (" "^(indent + 4)) * "u = $(ustr),")
314+
315+
# residstr = sprint(show, get_fu(cache); context = (:compact => true, :limit => true))
316+
# println(io, (" "^(indent + 4)) * "residual = $(residstr),")
317+
318+
# normstr = sprint(
319+
# show, norm(get_fu(cache), Inf); context = (:compact => true, :limit => true))
320+
# println(io, (" "^(indent + 4)) * "inf-norm(residual) = $(normstr),")
321+
322+
# println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
323+
# println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
324+
# print(io, " "^(indent) * ")")
325+
# end
225326

226327
"""
227328
AbstractLinearSolverCache

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
function SciMLBase.__solve(
2+
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...;
3+
kwargs...
4+
)
5+
cache = SciMLBase.init(prob, alg, args...; kwargs...)
6+
return SciMLBase.solve!(cache)
7+
end
8+
9+
function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
10+
while not_terminated(cache)
11+
SciMLBase.step!(cache)
12+
end
13+
14+
# The solver might have set a different `retcode`
15+
if cache.retcode == ReturnCode.Default
16+
cache.retcode = ifelse(
17+
cache.nsteps cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success
18+
)
19+
end
20+
21+
# XXX: Implement this
22+
# update_from_termination_cache!(cache.termination_cache, cache)
23+
24+
update_trace!(
25+
cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing;
26+
last = Val(true)
27+
)
28+
29+
return SciMLBase.build_solution(
30+
cache.prob, cache.alg, get_u(cache), get_fu(cache);
31+
cache.retcode, cache.stats, cache.trace
32+
)
33+
end
34+
35+
"""
36+
step!(cache::AbstractNonlinearSolveCache;
37+
recompute_jacobian::Union{Nothing, Bool} = nothing)
38+
39+
Performs one step of the nonlinear solver.
40+
41+
### Keyword Arguments
42+
43+
- `recompute_jacobian`: allows controlling whether the jacobian is recomputed at the
44+
current step. If `nothing`, then the algorithm determines whether to recompute the
45+
jacobian. If `true` or `false`, then the jacobian is recomputed or not recomputed,
46+
respectively. For algorithms that don't use jacobian information, this keyword is
47+
ignored with a one-time warning.
48+
"""
49+
function SciMLBase.step!(cache::AbstractNonlinearSolveCache, args...; kwargs...)
50+
not_terminated(cache) || return
51+
52+
has_time_limit(cache) && (time_start = time())
53+
54+
res = @static_timeit cache.timer "solve" begin
55+
InternalAPI.step!(cache, args...; kwargs...)
56+
end
57+
58+
cache.stats.nsteps += 1
59+
cache.nsteps += 1
60+
61+
if has_time_limit(cache)
62+
cache.total_time += time() - time_start
63+
64+
if !cache.force_stop && cache.retcode == ReturnCode.Default &&
65+
cache.total_time cache.maxtime
66+
cache.retcode = ReturnCode.MaxTime
67+
cache.force_stop = true
68+
end
69+
end
70+
71+
return res
72+
end
73+
74+
# Some algorithms don't support creating a cache and doing `solve!`, this unfortunately
75+
# makes it difficult to write generic code that supports caching. For the algorithms that
76+
# don't have a `__init` function defined, we create a "Fake Cache", which just calls
77+
# `__solve` from `solve!`
78+
# Warning: This doesn't implement all the necessary interface functions
79+
@concrete mutable struct NonlinearSolveNoInitCache <: AbstractNonlinearSolveCache
80+
prob
81+
alg
82+
args
83+
kwargs::Any
84+
end
85+
86+
function SciMLBase.reinit!(
87+
cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs...
88+
)
89+
cache.prob = SciMLBase.remake(cache.prob; u0, p)
90+
cache.kwargs = merge(cache.kwargs, kwargs)
91+
return cache
92+
end
93+
94+
function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolveNoInitCache)
95+
print(io, "NonlinearSolveNoInitCache(alg = $(cache.alg))")
96+
end
97+
98+
function SciMLBase.__init(
99+
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...;
100+
kwargs...
101+
)
102+
return NonlinearSolveNoInitCache(prob, alg, args, kwargs)
103+
end
104+
105+
function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
106+
return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
107+
end

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using SciMLOperators: AbstractSciMLOperator
99
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction
1010
using StaticArraysCore: StaticArray, SArray
1111

12-
using ..NonlinearSolveBase: L2_NORM, Linf_NORM
12+
using ..NonlinearSolveBase: NonlinearSolveBase, L2_NORM, Linf_NORM
1313

1414
is_extension_loaded(::Val) = false
1515

@@ -145,6 +145,26 @@ function evaluate_f!!(f::NonlinearFunction, fu, u, p)
145145
return f(u, p)
146146
end
147147

148+
function evaluate_f(prob::AbstractNonlinearProblem, u)
149+
if SciMLBase.isinplace(prob)
150+
fu = prob.f.resid_prototype === nothing ? similar(u) :
151+
similar(prob.f.resid_prototype)
152+
prob.f(fu, u, prob.p)
153+
else
154+
fu = prob.f(u, prob.p)
155+
end
156+
return fu
157+
end
158+
159+
function evaluate_f!(cache, u, p)
160+
cache.stats.nf += 1
161+
if SciMLBase.isinplace(cache)
162+
cache.prob.f(NonlinearSolveBase.get_fu(cache), u, p)
163+
else
164+
NonlinearSolveBase.set_fu!(cache, cache.prob.f(u, p))
165+
end
166+
end
167+
148168
function make_sparse end
149169

150170
condition_number(J::AbstractMatrix) = cond(J)

lib/NonlinearSolveSpectralMethods/Project.toml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,36 @@ uuid = "26075421-4e9a-44e1-8bd1-420ed7ad02b2"
33
authors = ["Avik Pal <[email protected]> and contributors"]
44
version = "1.0.0"
55

6+
[deps]
7+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
9+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
10+
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
13+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
14+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
15+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
16+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
17+
618
[compat]
719
Aqua = "0.8"
20+
CommonSolve = "0.2.4"
21+
ConcreteStructs = "0.2.3"
22+
DiffEqBase = "6.155.3"
823
ExplicitImports = "1.5"
924
Hwloc = "3"
1025
InteractiveUtils = "<0.0.1, 1"
26+
LineSearch = "0.1.4"
27+
LinearAlgebra = "1.11.0"
28+
MaybeInplace = "0.1.4"
1129
NonlinearProblemLibrary = "0.1.2"
30+
NonlinearSolveBase = "1.1"
1231
Pkg = "1.10"
32+
PrecompileTools = "1.2"
1333
ReTestItems = "1.24"
34+
Reexport = "1"
35+
SciMLBase = "2.54"
1436
StableRNGs = "1"
1537
Test = "1.10"
1638
julia = "1.10"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,41 @@
11
module NonlinearSolveSpectralMethods
22

3+
using Reexport: @reexport
4+
using PrecompileTools: @compile_workload, @setup_workload
5+
6+
using CommonSolve: CommonSolve
7+
using ConcreteStructs: @concrete
8+
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
9+
using LinearAlgebra: dot
10+
using LineSearch: RobustNonMonotoneLineSearch
11+
using MaybeInplace: @bb
12+
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
13+
AbstractNonlinearSolveCache, Utils, InternalAPI, get_timer_output,
14+
@static_timeit, update_trace!
15+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
16+
17+
include("dfsane.jl")
18+
19+
include("solve.jl")
20+
21+
@setup_workload begin
22+
include(joinpath(
23+
@__DIR__, "..", "..", "..", "common", "nonlinear_problem_workloads.jl"
24+
))
25+
26+
algs = [DFSane()]
27+
28+
@compile_workload begin
29+
@sync begin
30+
for prob in nonlinear_problems, alg in algs
31+
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
32+
end
33+
end
34+
end
35+
end
36+
37+
@reexport using SciMLBase, NonlinearSolveBase
38+
39+
export GeneralizedDFSane, DFSane
40+
341
end

0 commit comments

Comments
 (0)