Skip to content

Commit a12bd70

Browse files
committed
feat: SimpleImplicitDiscreteSolve
1 parent 7a4a5fb commit a12bd70

File tree

4 files changed

+215
-0
lines changed

4 files changed

+215
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 SciML
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
name = "SimpleImplicitDiscreteSolve"
2+
uuid = "8b67ef88-54bd-43ff-aca0-8be02588656a"
3+
authors = ["vyudu <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
9+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
10+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
11+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
13+
[compat]
14+
DiffEqBase = "6.164.1"
15+
OrdinaryDiffEqSDIRK = "1.2.0"
16+
Reexport = "1.2.2"
17+
SciMLBase = "2.74.1"
18+
SimpleNonlinearSolve = "2.1.0"
19+
StaticArrays = "1.9.13"
20+
Test = "1.10"
21+
22+
[extras]
23+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
24+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
25+
26+
[targets]
27+
test = ["OrdinaryDiffEqSDIRK", "Test"]
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
module SimpleImplicitDiscreteSolve
2+
3+
using SciMLBase
4+
using SimpleNonlinearSolve
5+
using Reexport
6+
using StaticArrays
7+
@reexport using DiffEqBase
8+
9+
"""
10+
SimpleIDSolve()
11+
12+
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
13+
"""
14+
struct SimpleIDSolve <: SciMLBase.AbstractODEAlgorithm end
15+
16+
function DiffEqBase.__init(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve; dt = 1)
17+
u0 = prob.u0
18+
p = prob.p
19+
f = prob.f
20+
t = prob.tspan[1]
21+
22+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, u0, p, t) : (u, p) -> f(u, u0, p, t)
23+
prob = NonlinearProblem{isinplace(f)}(nlf, u0, p)
24+
sol = solve(prob, SimpleNewtonRaphson())
25+
sol, (sol.retcode != ReturnCode.Success)
26+
end
27+
28+
function DiffEqBase.solve(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve;
29+
dt = 1,
30+
save_everystep = true,
31+
save_start = true,
32+
adaptive = false,
33+
dense = false,
34+
save_end = true,
35+
kwargs...)
36+
@assert !adaptive
37+
@assert !dense
38+
(initsol, initfail) = DiffEqBase.__init(prob, alg; dt)
39+
if initfail
40+
sol = DiffEqBase.build_solution(prob, alg, prob.tspan[1], u0, k = nothing,
41+
stats = nothing, calculate_error = false)
42+
return SciMLBase.solution_new_retcode(sol, ReturnCode.InitialFailure)
43+
end
44+
45+
u0 = initsol.u
46+
tspan = prob.tspan
47+
f = prob.f
48+
p = prob.p
49+
t = tspan[1]
50+
tf = prob.tspan[2]
51+
ts = tspan[1]:dt:tspan[2]
52+
53+
l = save_everystep ? length(ts) - 1 : 1
54+
save_start && (l = l + 1)
55+
u0type = typeof(u0)
56+
us = u0type <: StaticArray ? MVector{l, u0type}(undef) : Vector{u0type}(undef, l)
57+
58+
if save_start
59+
us[1] = u0
60+
end
61+
62+
u = u0
63+
convfail = false
64+
for i in 2:length(ts)
65+
uprev = u
66+
t = ts[i]
67+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, uprev, p, t) :
68+
(u, p) -> f(u, uprev, p, t)
69+
nlprob = NonlinearProblem{isinplace(f)}(nlf, uprev, p)
70+
nlsol = solve(nlprob, SimpleNewtonRaphson())
71+
u = nlsol.u
72+
save_everystep && (us[i] = u)
73+
convfail = (nlsol.retcode != ReturnCode.Success)
74+
75+
if convfail
76+
sol = DiffEqBase.build_solution(prob, alg, ts[1:i], us[1:i], k = nothing,
77+
stats = nothing, calculate_error = false)
78+
sol = SciMLBase.solution_new_retcode(sol, ReturnCode.ConvergenceFailure)
79+
return sol
80+
end
81+
end
82+
83+
!save_everystep && save_end && (us[end] = u)
84+
sol = DiffEqBase.build_solution(prob, alg, ts, us,
85+
k = nothing, stats = nothing,
86+
calculate_error = false)
87+
88+
DiffEqBase.has_analytic(prob.f) &&
89+
DiffEqBase.calculate_solution_errors!(
90+
sol; timeseries_errors = true, dense_errors = false)
91+
sol
92+
end
93+
94+
export SimpleIDSolve
95+
96+
end
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#runtests
2+
using Test
3+
using SimpleImplicitDiscreteSolve
4+
using OrdinaryDiffEqSDIRK
5+
6+
# Test implicit Euler using ImplicitDiscreteProblem
7+
@testset "Implicit Euler" begin
8+
function lotkavolterra(u, p, t)
9+
[1.5 * u[1] - u[1] * u[2], -3.0 * u[2] + u[1] * u[2]]
10+
end
11+
12+
function f!(resid, u_next, u, p, t)
13+
lv = lotkavolterra(u_next, p, t)
14+
resid[1] = u_next[1] - u[1] - 0.01 * lv[1]
15+
resid[2] = u_next[2] - u[2] - 0.01 * lv[2]
16+
nothing
17+
end
18+
u0 = [1.0, 1.0]
19+
tspan = (0.0, 0.5)
20+
21+
idprob = ImplicitDiscreteProblem(f!, u0, tspan, [])
22+
idsol = solve(idprob, SimpleIDSolve(), dt = 0.01)
23+
24+
oprob = ODEProblem(lotkavolterra, u0, tspan)
25+
osol = solve(oprob, ImplicitEuler())
26+
27+
@test isapprox(idsol[end - 1], osol[end], atol = 0.1)
28+
29+
### free-fall
30+
# y, dy
31+
function ff(u, p, t)
32+
[u[2], -9.8]
33+
end
34+
35+
function g!(resid, u_next, u, p, t)
36+
f = ff(u_next, p, t)
37+
resid[1] = u_next[1] - u[1] - 0.01 * f[1]
38+
resid[2] = u_next[2] - u[2] - 0.01 * f[2]
39+
nothing
40+
end
41+
u0 = [10.0, 0.0]
42+
tspan = (0, 0.5)
43+
44+
idprob = ImplicitDiscreteProblem(g!, u0, tspan, [])
45+
idsol = solve(idprob, SimpleIDSolve(); dt = 0.01)
46+
47+
oprob = ODEProblem(ff, u0, tspan)
48+
osol = solve(oprob, ImplicitEuler())
49+
50+
@test isapprox(idsol[end - 1], osol[end], atol = 0.1)
51+
end
52+
53+
@testset "Solver initializes" begin
54+
function periodic!(resid, u_next, u, p, t)
55+
resid[1] = u_next[1] - u[1] - sin(t * π / 4)
56+
resid[2] = 16 - u_next[2]^2 - u_next[1]^2
57+
end
58+
59+
tsteps = 15
60+
u0 = [1.0, 3.0]
61+
idprob = ImplicitDiscreteProblem(periodic!, u0, (0, tsteps), [])
62+
initsol, initfail = DiffEqBase.__init(idprob, SimpleIDSolve())
63+
@test initsol.u[1]^2 + initsol.u[2]^2 16
64+
65+
idsol = solve(idprob, SimpleIDSolve())
66+
67+
for ts in 1:tsteps
68+
step = idsol.u[ts]
69+
@test step[1]^2 + step[2]^2 16
70+
end
71+
end

0 commit comments

Comments
 (0)