Skip to content

Commit 4baf777

Browse files
committed
Handle termination conditions in reinit!
1 parent eddcfde commit 4baf777

File tree

4 files changed

+30
-5
lines changed

4 files changed

+30
-5
lines changed

src/dfsane.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,8 @@ function SciMLBase.solve!(cache::DFSaneCache)
315315
end
316316

317317
function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
318-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
318+
abstol = cache.abstol, termination_condition = cache.termination_condition,
319+
maxiters = cache.maxiters) where {iip}
319320
cache.p = p
320321
if iip
321322
recursivecopy!(cache.uₙ, u0)
@@ -336,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
336337
T = eltype(cache.uₙ)
337338
cache.σₙ = T(cache.alg.σ_1)
338339

340+
termination_condition = _get_reinit_termination_condition(cache,
341+
abstol,
342+
reltol,
343+
termination_condition)
344+
339345
cache.abstol = abstol
346+
cache.reltol = reltol
347+
cache.termination_condition = termination_condition
340348
cache.maxiters = maxiters
341349
cache.stats.nf = 1
342350
cache.stats.nsteps = 1

src/gaussnewton.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ function perform_step!(cache::GaussNewtonCache{false})
181181
end
182182

183183
function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
184-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
184+
abstol = cache.abstol, reltol = cache.reltol,
185+
termination_condition = cache.termination_condition,
186+
maxiters = cache.maxiters) where {iip}
185187
cache.p = p
186188
if iip
187189
recursivecopy!(cache.u, u0)
@@ -191,7 +193,14 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
191193
cache.u = u0
192194
cache.fu1 = cache.f(cache.u, p)
193195
end
196+
termination_condition = _get_reinit_termination_condition(cache,
197+
abstol,
198+
reltol,
199+
termination_condition)
200+
194201
cache.abstol = abstol
202+
cache.reltol = reltol
203+
cache.termination_condition = termination_condition
195204
cache.maxiters = maxiters
196205
cache.stats.nf = 1
197206
cache.stats.nsteps = 1

src/trustRegion.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,8 @@ end
724724
get_fu(cache::TrustRegionCache) = cache.fu
725725

726726
function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
727-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
727+
abstol = cache.abstol, termination_condition = cache.termination_condition,
728+
maxiters = cache.maxiters) where {iip}
728729
cache.p = p
729730
if iip
730731
recursivecopy!(cache.u, u0)
@@ -734,7 +735,14 @@ function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache
734735
cache.u = u0
735736
cache.fu = cache.f(cache.u, p)
736737
end
738+
termination_condition = _get_reinit_termination_condition(cache,
739+
abstol,
740+
reltol,
741+
termination_condition)
742+
737743
cache.abstol = abstol
744+
cache.reltol = reltol
745+
cache.termination_condition = termination_condition
738746
cache.maxiters = maxiters
739747
cache.stats.nf = 1
740748
cache.stats.nsteps = 1

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,13 @@ function _get_reinit_termination_condition(cache, abstol, reltol, termination_co
244244
if termination_condition != cache.termination_condition
245245
if abstol != cache.abstol
246246
if abstol != termination_condition.abstol
247-
error("Incompatible absolute tolerances found")
247+
error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.")
248248
end
249249
end
250250

251251
if reltol != cache.reltol
252252
if reltol != termination_condition.reltol
253-
error("Incompatible relative tolerances found")
253+
error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.")
254254
end
255255
end
256256
termination_condition

0 commit comments

Comments
 (0)