Skip to content

Commit 58e9215

Browse files
Merge pull request #2679 from vyudu/sids
SimpleImplicitDiscreteSolve
2 parents 7a4a5fb + 1116177 commit 58e9215

File tree

6 files changed

+212
-1
lines changed

6 files changed

+212
-1
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ jobs:
5757
- OrdinaryDiffEqVerner
5858

5959
- ImplicitDiscreteSolve
60+
- SimpleImplicitDiscreteSolve
6061

6162
version:
6263
- 'lts'
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 SciML
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name = "SimpleImplicitDiscreteSolve"
2+
uuid = "8b67ef88-54bd-43ff-aca0-8be02588656a"
3+
authors = ["vyudu <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
9+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
10+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
11+
12+
[compat]
13+
DiffEqBase = "6.164.1"
14+
OrdinaryDiffEqSDIRK = "1.2.0"
15+
Reexport = "1.2.2"
16+
SciMLBase = "2.74.1"
17+
SimpleNonlinearSolve = "2.1.0"
18+
Test = "1.10"
19+
20+
[extras]
21+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
22+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
24+
[targets]
25+
test = ["OrdinaryDiffEqSDIRK", "Test"]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
module SimpleImplicitDiscreteSolve
2+
3+
using SciMLBase
4+
using SimpleNonlinearSolve
5+
using Reexport
6+
@reexport using DiffEqBase
7+
8+
"""
9+
SimpleIDSolve()
10+
11+
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
12+
"""
13+
struct SimpleIDSolve <: SciMLBase.AbstractODEAlgorithm end
14+
15+
function DiffEqBase.__init(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve; dt = 1)
16+
u0 = prob.u0
17+
p = prob.p
18+
f = prob.f
19+
t = prob.tspan[1]
20+
21+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, u0, p, t) : (u, p) -> f(u, u0, p, t)
22+
prob = NonlinearProblem{isinplace(f)}(nlf, u0, p)
23+
sol = solve(prob, SimpleNewtonRaphson())
24+
sol, (sol.retcode != ReturnCode.Success)
25+
end
26+
27+
function DiffEqBase.solve(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve;
28+
dt = 1,
29+
save_everystep = true,
30+
save_start = true,
31+
adaptive = false,
32+
dense = false,
33+
save_end = true,
34+
kwargs...)
35+
@assert !adaptive
36+
@assert !dense
37+
(initsol, initfail) = DiffEqBase.__init(prob, alg; dt)
38+
if initfail
39+
sol = DiffEqBase.build_solution(prob, alg, prob.tspan[1], u0, k = nothing,
40+
stats = nothing, calculate_error = false)
41+
return SciMLBase.solution_new_retcode(sol, ReturnCode.InitialFailure)
42+
end
43+
44+
u0 = initsol.u
45+
tspan = prob.tspan
46+
f = prob.f
47+
p = prob.p
48+
t = tspan[1]
49+
tf = prob.tspan[2]
50+
ts = tspan[1]:dt:tspan[2]
51+
52+
l = save_everystep ? length(ts) - 1 : 1
53+
save_start && (l = l + 1)
54+
us = Vector{typeof(u0)}(undef, l)
55+
56+
if save_start
57+
us[1] = u0
58+
end
59+
60+
u = u0
61+
convfail = false
62+
for i in 2:length(ts)
63+
uprev = u
64+
t = ts[i]
65+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, uprev, p, t) :
66+
(u, p) -> f(u, uprev, p, t)
67+
nlprob = NonlinearProblem{isinplace(f)}(nlf, uprev, p)
68+
nlsol = solve(nlprob, SimpleNewtonRaphson())
69+
u = nlsol.u
70+
save_everystep && (us[i] = u)
71+
convfail = (nlsol.retcode != ReturnCode.Success)
72+
73+
if convfail
74+
sol = DiffEqBase.build_solution(prob, alg, ts[1:i], us[1:i], k = nothing,
75+
stats = nothing, calculate_error = false)
76+
sol = SciMLBase.solution_new_retcode(sol, ReturnCode.ConvergenceFailure)
77+
return sol
78+
end
79+
end
80+
81+
!save_everystep && save_end && (us[end] = u)
82+
sol = DiffEqBase.build_solution(prob, alg, ts, us,
83+
k = nothing, stats = nothing,
84+
calculate_error = false)
85+
86+
DiffEqBase.has_analytic(prob.f) &&
87+
DiffEqBase.calculate_solution_errors!(
88+
sol; timeseries_errors = true, dense_errors = false)
89+
sol
90+
end
91+
92+
export SimpleIDSolve
93+
94+
end
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using Test
2+
using SimpleImplicitDiscreteSolve
3+
using OrdinaryDiffEqSDIRK
4+
5+
# Test implicit Euler using ImplicitDiscreteProblem
6+
@testset "Implicit Euler" begin
7+
function lotkavolterra(u, p, t)
8+
[1.5 * u[1] - u[1] * u[2], -3.0 * u[2] + u[1] * u[2]]
9+
end
10+
11+
function f!(resid, u_next, u, p, t)
12+
lv = lotkavolterra(u_next, p, t)
13+
resid[1] = u_next[1] - u[1] - 0.01 * lv[1]
14+
resid[2] = u_next[2] - u[2] - 0.01 * lv[2]
15+
nothing
16+
end
17+
u0 = [1.0, 1.0]
18+
tspan = (0.0, 0.5)
19+
20+
idprob = ImplicitDiscreteProblem(f!, u0, tspan, [])
21+
idsol = solve(idprob, SimpleIDSolve(), dt = 0.01)
22+
23+
oprob = ODEProblem(lotkavolterra, u0, tspan)
24+
osol = solve(oprob, ImplicitEuler())
25+
26+
@test isapprox(idsol[end - 1], osol[end], atol = 0.1)
27+
28+
### free-fall
29+
# y, dy
30+
function ff(u, p, t)
31+
[u[2], -9.8]
32+
end
33+
34+
function g!(resid, u_next, u, p, t)
35+
f = ff(u_next, p, t)
36+
resid[1] = u_next[1] - u[1] - 0.01 * f[1]
37+
resid[2] = u_next[2] - u[2] - 0.01 * f[2]
38+
nothing
39+
end
40+
u0 = [10.0, 0.0]
41+
tspan = (0, 0.5)
42+
43+
idprob = ImplicitDiscreteProblem(g!, u0, tspan, [])
44+
idsol = solve(idprob, SimpleIDSolve(); dt = 0.01)
45+
46+
oprob = ODEProblem(ff, u0, tspan)
47+
osol = solve(oprob, ImplicitEuler())
48+
49+
@test isapprox(idsol[end - 1], osol[end], atol = 0.1)
50+
end
51+
52+
@testset "Solver initializes" begin
53+
function periodic!(resid, u_next, u, p, t)
54+
resid[1] = u_next[1] - u[1] - sin(t * π / 4)
55+
resid[2] = 16 - u_next[2]^2 - u_next[1]^2
56+
end
57+
58+
tsteps = 15
59+
u0 = [1.0, 3.0]
60+
idprob = ImplicitDiscreteProblem(periodic!, u0, (0, tsteps), [])
61+
initsol, initfail = DiffEqBase.__init(idprob, SimpleIDSolve())
62+
@test initsol.u[1]^2 + initsol.u[2]^2 16
63+
64+
idsol = solve(idprob, SimpleIDSolve())
65+
66+
for ts in 1:tsteps
67+
step = idsol.u[ts]
68+
@test step[1]^2 + step[2]^2 16
69+
end
70+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
#Start Test Script
2727

2828
@time begin
29-
if contains(GROUP, "OrdinaryDiffEq") || GROUP == "ImplicitDiscreteSolve"
29+
if contains(GROUP, "OrdinaryDiffEq") || GROUP == "ImplicitDiscreteSolve" || GROUP == "SimpleImplicitDiscreteSolve"
3030
Pkg.develop(path = "../lib/$GROUP")
3131
Pkg.test(GROUP)
3232
elseif GROUP == "All" || GROUP == "InterfaceI" || GROUP == "Interface"

0 commit comments

Comments
 (0)