Skip to content

Commit e9b7146

Browse files
committed
Use the new Termination Condition API
1 parent 12d9743 commit e9b7146

File tree

8 files changed

+1009
-60
lines changed

8 files changed

+1009
-60
lines changed

Manifest.toml

Lines changed: 956 additions & 0 deletions
Large diffs are not rendered by default.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ ADTypes = "0.2"
4040
ArrayInterface = "6.0.24, 7"
4141
BandedMatrices = "1"
4242
ConcreteStructs = "0.2"
43-
DiffEqBase = "6.130"
43+
DiffEqBase = "6.136"
4444
EnumX = "1"
4545
Enzyme = "0.11"
4646
FastBroadcast = "0.1.9, 0.2"
@@ -56,7 +56,7 @@ RecursiveArrayTools = "2"
5656
Reexport = "0.2, 1"
5757
SciMLBase = "2.4"
5858
SimpleNonlinearSolve = "0.1.23"
59-
SparseDiffTools = "2.6"
59+
SparseDiffTools = "2.9"
6060
StaticArraysCore = "1.4"
6161
UnPack = "1.0"
6262
Zygote = "0.6"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ PrecompileTools.@recompile_invalidations begin
3030
end
3131

3232
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
33+
import DiffEqBase: AbstractNonlinearTerminationMode
3334

3435
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
3536
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

src/broyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
reset_check
5353
prob
5454
stats::NLStats
55-
lscache
55+
ls_cache
5656
termination_condition
5757
tc_storage
5858
end
@@ -93,7 +93,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
9393
T = eltype(u)
9494

9595
mul!(_vec(du), J⁻¹, _vec(fu))
96-
α = perform_linesearch!(cache.lscache, u, du)
96+
α = perform_linesearch!(cache.ls_cache, u, du)
9797
_axpy!(-α, du, u)
9898
f(fu2, u, p)
9999

@@ -137,7 +137,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
137137
T = eltype(cache.u)
138138

139139
cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
140-
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
140+
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
141141
cache.u = cache.u .- α * cache.du
142142
cache.fu2 = f(cache.u, p)
143143

src/klement.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161
reltol
6262
prob
6363
stats::NLStats
64-
lscache
64+
ls_cache
6565
termination_condition
6666
tc_storage
6767
end
@@ -131,7 +131,7 @@ function perform_step!(cache::GeneralKlementCache{true})
131131
cache.linsolve = linres.cache
132132

133133
# Line Search
134-
α = perform_linesearch!(cache.lscache, u, du)
134+
α = perform_linesearch!(cache.ls_cache, u, du)
135135
_axpy!(-α, du, u)
136136
f(cache.fu2, u, p)
137137

@@ -193,7 +193,7 @@ function perform_step!(cache::GeneralKlementCache{false})
193193
end
194194

195195
# Line Search
196-
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
196+
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
197197
cache.u = @. cache.u - α * cache.du # `u` might not support mutation
198198
cache.fu2 = f(cache.u, p)
199199

src/lbroyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
reset_check
6060
prob
6161
stats::NLStats
62-
lscache
62+
ls_cache
6363
termination_condition
6464
tc_storage
6565
end
@@ -109,7 +109,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
109109

110110
termination_condition = cache.termination_condition(tc_storage)
111111

112-
α = perform_linesearch!(cache.lscache, u, du)
112+
α = perform_linesearch!(cache.ls_cache, u, du)
113113
_axpy!(-α, du, u)
114114
f(cache.fu2, u, p)
115115

@@ -169,7 +169,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
169169

170170
T = eltype(cache.u)
171171

172-
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
172+
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
173173
cache.u = cache.u .- α * cache.du
174174
cache.fu2 = f(cache.u, p)
175175

src/raphson.jl

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@ end
7070
reltol
7171
prob
7272
stats::NLStats
73-
lscache
74-
termination_condition
75-
tc_storage
73+
ls_cache
74+
tc_cache
7675
end
7776

7877
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
@@ -86,39 +85,34 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8685
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
8786
linsolve_kwargs)
8887

89-
abstol, reltol, termination_condition = _init_termination_elements(abstol,
90-
reltol, termination_condition, eltype(u))
91-
92-
mode = DiffEqBase.get_termination_mode(termination_condition)
93-
94-
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
95-
nothing
88+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, u,
89+
termination_condition)
9690

9791
return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
9892
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
9993
NLStats(1, 0, 0, 0, 0),
100-
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)),
101-
termination_condition, storage)
94+
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), tc_cache)
10295
end
10396

10497
function perform_step!(cache::NewtonRaphsonCache{true})
105-
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, tc_storage = cache
98+
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
10699
jacobian!!(J, cache)
107100

108-
termination_condition = cache.termination_condition(tc_storage)
109-
110101
# u = u - J \ fu
111102
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
112103
p, reltol = cache.abstol)
113104
cache.linsolve = linres.cache
114105

115106
# Line Search
116-
α = perform_linesearch!(cache.lscache, u, du)
107+
α = perform_linesearch!(cache.ls_cache, u, du)
117108
_axpy!(-α, du, u)
118109
f(cache.fu1, u, p)
119110

120-
termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) &&
121-
(cache.force_stop = true)
111+
if cache.tc_cache(cache.fu1, cache.u, u_prev)
112+
# Stores the best solution in cache!
113+
cache.tc_cache.u !== nothing && copyto!(cache.u, cache.tc_cache.u)
114+
cache.force_stop = true
115+
end
122116

123117
@. u_prev = u
124118
cache.stats.nf += 1
@@ -129,9 +123,7 @@ function perform_step!(cache::NewtonRaphsonCache{true})
129123
end
130124

131125
function perform_step!(cache::NewtonRaphsonCache{false})
132-
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache
133-
134-
termination_condition = cache.termination_condition(tc_storage)
126+
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
135127

136128
cache.J = jacobian!!(cache.J, cache)
137129
# u = u - J \ fu
@@ -144,14 +136,17 @@ function perform_step!(cache::NewtonRaphsonCache{false})
144136
end
145137

146138
# Line Search
147-
α = perform_linesearch!(cache.lscache, u, cache.du)
139+
α = perform_linesearch!(cache.ls_cache, u, cache.du)
148140
cache.u = @. u - α * cache.du # `u` might not support mutation
149141
cache.fu1 = f(cache.u, p)
150142

151-
termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) &&
152-
(cache.force_stop = true)
143+
if cache.tc_cache(cache.fu1, cache.u, u_prev)
144+
# Stores the best solution in cache!
145+
cache.tc_cache.u !== nothing && (cache.u = cache.tc_cache.u)
146+
cache.force_stop = true
147+
end
153148

154-
cache.u_prev = @. cache.u
149+
cache.u_prev = cache.u
155150
cache.stats.nf += 1
156151
cache.stats.njacs += 1
157152
cache.stats.nsolve += 1
@@ -161,7 +156,7 @@ end
161156

162157
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
163158
abstol = cache.abstol, reltol = cache.reltol,
164-
termination_condition = cache.termination_condition,
159+
termination_condition = get_termination_mode(cache.tc_cache),
165160
maxiters = cache.maxiters) where {iip}
166161
cache.p = p
167162
if iip
@@ -173,12 +168,12 @@ function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cac
173168
cache.fu1 = cache.f(cache.u, p)
174169
end
175170

176-
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
171+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.u,
177172
termination_condition)
178173

179174
cache.abstol = abstol
180175
cache.reltol = reltol
181-
cache.termination_condition = termination_condition
176+
cache.tc_cache = tc_cache
182177
cache.maxiters = maxiters
183178
cache.stats.nf = 1
184179
cache.stats.nsteps = 1

src/utils.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,4 @@
1-
2-
@inline UNITLESS_ABS2(x) = real(abs2(x))
3-
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
4-
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
5-
return sqrt(real(sum(abs2, u)) / length(u))
6-
end
7-
@inline function DEFAULT_NORM(u::StaticArray{<:Union{AbstractFloat, Complex}})
8-
return sqrt(real(sum(abs2, u)) / length(u))
9-
end
10-
@inline function DEFAULT_NORM(u::AbstractVectorOfArray)
11-
return sum(sqrt(real(sum(UNITLESS_ABS2, _u)) / length(_u)) for _u in u.u)
12-
end
13-
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
14-
@inline DEFAULT_NORM(u) = norm(u)
1+
const DEFAULT_NORM = DiffEqBase.NONLINEARSOLVE_DEFAULT_NORM
152

163
# Ignores NaN
174
function __findmin(f, x)
@@ -36,7 +23,7 @@ code.
3623
`autodiff=<ADTypes>`.
3724
"""
3825
function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing,
39-
standardtag = missing, diff_type = missing)
26+
standardtag = missing, diff_type = missing)
4027
# If using the new API short circuit
4128
autodiff === nothing && return nothing
4229
autodiff isa ADTypes.AbstractADType && return autodiff
@@ -89,8 +76,8 @@ end
8976
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
9077

9178
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
92-
du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing,
93-
cachedata = nothing, reltol = nothing) where {P}
79+
du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing,
80+
cachedata = nothing, reltol = nothing) where {P}
9481
A !== nothing && (linsolve.A = A)
9582
b !== nothing && (linsolve.b = b)
9683
linu !== nothing && (linsolve.u = linu)
@@ -167,7 +154,7 @@ _maybe_mutable(x, _) = x
167154

168155
# Helper function to get value of `f(u, p)`
169156
function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
170-
NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip}
157+
NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip}
171158
@unpack f, u0, p = prob
172159
if iip
173160
fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
@@ -194,7 +181,7 @@ end
194181
195182
Defaults to `mul!(C, A, B)`. However, for sparse matrices uses `C .= A * B`.
196183
"""
197-
__matmul!(C, A, B) = mul!(C, A, B)
184+
__matmul!(C, A, B) = mul!(C, A, B)``
198185
__matmul!(C::AbstractSparseMatrix, A, B) = C .= A * B
199186

200187
# Concretize Algorithms
@@ -216,11 +203,20 @@ function __get_concrete_algorithm(alg, prob)
216203
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
217204
else
218205
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
219-
tag = NonlinearSolveTag())
206+
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
220207
end
221208
return set_ad(alg, ad)
222209
end
223210

211+
function init_termination_cache(abstol, reltol, u, ::Nothing)
212+
return init_termination_cache(abstol, reltol, u, AbsNormTerminationMode())
213+
end
214+
function init_termination_cache(abstol, reltol, u, tc::AbstractNonlinearTerminationMode)
215+
tc_cache = init(u, tc; abstol, reltol)
216+
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
217+
end
218+
219+
# FIXME: Remove the functions below when we have migrated to the new type stable API
224220
__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
225221
__cvt_real(::Type{T}, x) where {T} = real(T(x))
226222

@@ -231,7 +227,7 @@ function _get_tolerance(η, tc_η, ::Type{T}) where {T}
231227
end
232228

233229
function _init_termination_elements(abstol, reltol, termination_condition,
234-
::Type{T}; mode = NLSolveTerminationMode.AbsNorm) where {T}
230+
::Type{T}; mode = NLSolveTerminationMode.AbsNorm) where {T}
235231
if termination_condition !== nothing
236232
if abstol !== nothing && abstol != termination_condition.abstol
237233
error("Incompatible absolute tolerances found. The tolerances supplied as the \
@@ -279,6 +275,7 @@ function _get_reinit_termination_condition(cache, abstol, reltol, termination_co
279275
termination_condition.safe_termination_options)
280276
end
281277
end
278+
# FIXME: Purge things till here!
282279

283280
__init_identity_jacobian(u::Number, _) = u
284281
function __init_identity_jacobian(u, fu)

0 commit comments

Comments
 (0)