Skip to content

Commit bddcff5

Browse files
Merge pull request #90 from SciML/ChrisRackauckas-patch-1
Fix tests
2 parents e6cbcd1 + a693e80 commit bddcff5

File tree

4 files changed

+14
-31
lines changed

4 files changed

+14
-31
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1011
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1112
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1213

src/SteadyStateDiffEq.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
module SteadyStateDiffEq
22

33
using Reexport: @reexport
4-
@reexport using DiffEqBase
4+
@reexport using SciMLBase
55

66
using ConcreteStructs: @concrete
77
using NonlinearSolveBase
8+
import DiffEqBase
89
using NonlinearSolveBase: AbstractNonlinearTerminationMode,
910
AbstractSafeNonlinearTerminationMode,
1011
AbstractSafeBestNonlinearTerminationMode,
11-
NonlinearSafeTerminationReturnCode, NormTerminationMode
12+
NormTerminationMode
1213
using DiffEqCallbacks: TerminateSteadyState
1314
using LinearAlgebra: norm
1415
using SciMLBase: SciMLBase, CallbackSet, NonlinearProblem, ODEProblem,

src/solve.jl

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515

1616
function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::DynamicSS,
1717
args...; abstol = 1e-8, reltol = 1e-6, odesolve_kwargs = (;),
18-
save_idxs = nothing, termination_condition = NormTerminationMode(infnorm),
18+
save_idxs = nothing, termination_condition = NonlinearSolveBase.NormTerminationMode(infnorm),
1919
kwargs...)
2020
tspan = __get_tspan(prob.u0, alg)
2121

@@ -36,9 +36,9 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::Dyna
3636
du = f(prob.u0, prob.p, first(tspan))
3737
end
3838

39-
tc_cache = init(du, prob.u0, termination_condition, last(tspan); abstol, reltol)
40-
abstol = DiffEqBase.get_abstol(tc_cache)
41-
reltol = DiffEqBase.get_reltol(tc_cache)
39+
tc_cache = init(prob, termination_condition, du, prob.u0; abstol, reltol)
40+
abstol = NonlinearSolveBase.get_abstol(tc_cache)
41+
reltol = NonlinearSolveBase.get_reltol(tc_cache)
4242

4343
function terminate_function(u, t, integrator)
4444
return tc_cache(get_du(integrator), integrator.u, integrator.uprev, t)
@@ -78,16 +78,7 @@ end
7878
function __get_result_from_sol(::AbstractSafeNonlinearTerminationMode, tc_cache, odesol)
7979
u, t = last(odesol.u), last(odesol.t)
8080
du = odesol(t, Val{1})
81-
82-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
83-
retcode_tc = ReturnCode.Success
84-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
85-
retcode_tc = ReturnCode.ConvergenceFailure
86-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
87-
retcode_tc = ReturnCode.Unstable
88-
else
89-
retcode_tc = ReturnCode.Default
90-
end
81+
retcode_tc = tc_cache.retcode
9182

9283
retcode = if odesol.retcode == ReturnCode.Terminated
9384
ifelse(retcode_tc != ReturnCode.Default, retcode_tc, ReturnCode.Success)
@@ -103,16 +94,7 @@ end
10394
function __get_result_from_sol(::AbstractSafeBestNonlinearTerminationMode, tc_cache, odesol)
10495
u, t = tc_cache.u, only(DiffEqBase.get_saved_values(tc_cache))
10596
du = odesol(t, Val{1})
106-
107-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
108-
retcode_tc = ReturnCode.Success
109-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
110-
retcode_tc = ReturnCode.ConvergenceFailure
111-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
112-
retcode_tc = ReturnCode.Unstable
113-
else
114-
retcode_tc = ReturnCode.Default
115-
end
97+
retcode_tc = tc_cache.retcode
11698

11799
retcode = if odesol.retcode == ReturnCode.Terminated
118100
ifelse(retcode_tc != ReturnCode.Default, retcode_tc, ReturnCode.Success)

test/core.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using SteadyStateDiffEq,
2-
DiffEqBase, NonlinearSolve, Sundials, OrdinaryDiffEq, DiffEqCallbacks, Test
1+
using SteadyStateDiffEq, NonlinearSolve, Sundials, OrdinaryDiffEq, DiffEqCallbacks, Test
2+
using NonlinearSolve.NonlinearSolveBase
3+
using NonlinearSolve.NonlinearSolveBase: NormTerminationMode, RelTerminationMode, RelNormTerminationMode,
4+
AbsTerminationMode, AbsNormTerminationMode
35

46
function f(du, u, p, t)
57
du[1] = 2 - 2u[1]
@@ -83,9 +85,6 @@ sol2 = solve(prob, DynamicSS(Tsit5()); abstol = 1e-4)
8385
for termination_condition in [
8486
NormTerminationMode(SteadyStateDiffEq.infnorm), RelTerminationMode(), RelNormTerminationMode(SteadyStateDiffEq.infnorm),
8587
AbsTerminationMode(), AbsNormTerminationMode(SteadyStateDiffEq.infnorm),
86-
RelSafeTerminationMode(SteadyStateDiffEq.infnorm),
87-
AbsSafeTerminationMode(SteadyStateDiffEq.infnorm), RelSafeBestTerminationMode(SteadyStateDiffEq.infnorm),
88-
AbsSafeBestTerminationMode(SteadyStateDiffEq.infnorm)
8988
]
9089
sol_tc = solve(prob, DynamicSS(Tsit5()); termination_condition)
9190
@show sol_tc.retcode, termination_condition

0 commit comments

Comments
 (0)