Skip to content

Commit 8a4dbcd

Browse files
Merge pull request #2626 from vyudu/ids
feat: ImplicitDiscreteSolve
2 parents f1a375a + 8eaf365 commit 8a4dbcd

File tree

8 files changed

+240
-1
lines changed

8 files changed

+240
-1
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ jobs:
5656
- OrdinaryDiffEqTsit5
5757
- OrdinaryDiffEqVerner
5858

59+
- ImplicitDiscreteSolve
60+
5961
version:
6062
- 'lts'
6163
- '1'
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name = "ImplicitDiscreteSolve"
2+
uuid = "3263718b-31ed-49cf-8a0f-35a466e8af96"
3+
authors = ["vyudu <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
9+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
10+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
11+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
12+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
13+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
14+
15+
[compat]
16+
DiffEqBase = "6.164.1"
17+
OrdinaryDiffEqCore = "1.18.1"
18+
OrdinaryDiffEqSDIRK = "1.2.0"
19+
Reexport = "1.2.2"
20+
SciMLBase = "2.74.1"
21+
SimpleNonlinearSolve = "2.1.0"
22+
SymbolicIndexingInterface = "0.3.38"
23+
Test = "1.10.0"
24+
UnPack = "1.0.2"
25+
26+
[extras]
27+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
28+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
29+
30+
[targets]
31+
test = ["OrdinaryDiffEqSDIRK", "Test"]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module ImplicitDiscreteSolve
2+
3+
using SciMLBase
4+
using SimpleNonlinearSolve
5+
using UnPack
6+
using SymbolicIndexingInterface: parameter_symbols
7+
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, get_fsalfirstlast, isfsal, initialize!, perform_step!, isdiscretecache, isdiscretealg, alg_order, beta2_default, beta1_default, dt_required, _initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit
8+
9+
using Reexport
10+
@reexport using DiffEqBase
11+
12+
"""
13+
SimpleIDSolve()
14+
15+
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
16+
"""
17+
struct SimpleIDSolve <: OrdinaryDiffEqAlgorithm end
18+
19+
include("cache.jl")
20+
include("solve.jl")
21+
include("alg_utils.jl")
22+
23+
export SimpleIDSolve
24+
25+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
function SciMLBase.isautodifferentiable(alg::SimpleIDSolve)
2+
true
3+
end
4+
function SciMLBase.allows_arbitrary_number_types(alg::SimpleIDSolve)
5+
true
6+
end
7+
function SciMLBase.allowscomplex(alg::SimpleIDSolve)
8+
true
9+
end
10+
11+
SciMLBase.isdiscrete(alg::SimpleIDSolve) = true
12+
13+
isfsal(alg::SimpleIDSolve) = false
14+
alg_order(alg::SimpleIDSolve) = 0
15+
beta2_default(alg::SimpleIDSolve) = 0
16+
beta1_default(alg::SimpleIDSolve, beta2) = 0
17+
18+
dt_required(alg::SimpleIDSolve) = false
19+
isdiscretealg(alg::SimpleIDSolve) = true
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
mutable struct ImplicitDiscreteState{uType, pType, tType}
2+
u::Vector{uType}
3+
p::pType
4+
t_next::tType
5+
end
6+
7+
mutable struct SimpleIDSolveCache{uType} <: OrdinaryDiffEqMutableCache
8+
u::uType
9+
uprev::uType
10+
state::ImplicitDiscreteState
11+
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
12+
end
13+
14+
function alg_cache(alg::SimpleIDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
15+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
16+
dt, reltol, p, calck,
17+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
18+
19+
state = ImplicitDiscreteState(similar(u), p, t)
20+
SimpleIDSolveCache(u, uprev, state, nothing)
21+
end
22+
23+
isdiscretecache(cache::SimpleIDSolveCache) = true
24+
25+
struct SimpleIDSolveConstantCache <: OrdinaryDiffEqConstantCache
26+
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
27+
end
28+
29+
function alg_cache(alg::SimpleIDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
30+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
31+
dt, reltol, p, calck,
32+
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
33+
34+
state = ImplicitDiscreteState(similar(u), p, t)
35+
SimpleIDSolveCache(u, uprev, state, nothing)
36+
end
37+
38+
get_fsalfirstlast(cache::SimpleIDSolveCache, rate_prototype) = (nothing, nothing)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Remake the nonlinear problem, then update
2+
function perform_step!(integrator, cache::SimpleIDSolveCache, repeat_step = false)
3+
@unpack alg, u, uprev, dt, t, f, p = integrator
4+
@unpack state, prob = cache
5+
state.u .= uprev
6+
state.t_next = t
7+
prob = remake(prob, p = state)
8+
9+
u = solve(prob, SimpleNewtonRaphson())
10+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, u.retcode)
11+
integrator.u = u
12+
end
13+
14+
function initialize!(integrator, cache::SimpleIDSolveCache)
15+
cache.state.u .= integrator.u
16+
cache.state.p = integrator.p
17+
cache.state.t_next = integrator.t
18+
f = integrator.f
19+
20+
_f = if isinplace(f)
21+
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
22+
else
23+
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
24+
end
25+
26+
prob = if isinplace(f)
27+
NonlinearProblem{true}(_f, cache.state.u, cache.state)
28+
else
29+
NonlinearProblem{false}(_f, cache.state.u, cache.state)
30+
end
31+
cache.prob = prob
32+
end
33+
34+
function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
35+
alg::DefaultInit, x::Union{Val{true}, Val{false}})
36+
atol = one(eltype(prob.u0)) * 1e-12
37+
if SciMLBase.has_initializeprob(prob.f)
38+
_initialize_dae!(integrator, prob,
39+
OverrideInit(atol), x)
40+
else
41+
@unpack u, p, t, f = integrator
42+
initstate = ImplicitDiscreteState(u, p, t)
43+
44+
_f = if isinplace(f)
45+
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
46+
else
47+
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
48+
end
49+
prob = NonlinearProblem{isinplace(f)}(_f, u, initstate)
50+
sol = solve(prob, SimpleNewtonRaphson())
51+
integrator.u = sol
52+
end
53+
end
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#runtests
2+
using Test
3+
using ImplicitDiscreteSolve
4+
using OrdinaryDiffEqCore
5+
using OrdinaryDiffEqSDIRK
6+
7+
# Test implicit Euler using ImplicitDiscreteProblem
8+
@testset "Implicit Euler" begin
9+
function lotkavolterra(u, p, t)
10+
[1.5*u[1] - u[1]*u[2], -3.0*u[2] + u[1]*u[2]]
11+
end
12+
13+
function f!(resid, u_next, u, p, t)
14+
lv = lotkavolterra(u_next, p, t)
15+
resid[1] = u_next[1] - u[1] - 0.01*lv[1]
16+
resid[2] = u_next[2] - u[2] - 0.01*lv[2]
17+
nothing
18+
end
19+
u0 = [1., 1.]
20+
tspan = (0., 0.5)
21+
22+
idprob = ImplicitDiscreteProblem(f!, u0, tspan, []; dt = 0.01)
23+
idsol = solve(idprob, SimpleIDSolve())
24+
25+
oprob = ODEProblem(lotkavolterra, u0, tspan)
26+
osol = solve(oprob, ImplicitEuler())
27+
28+
@test isapprox(idsol[end], osol[end], atol = 0.1)
29+
30+
### free-fall
31+
# y, dy
32+
function ff(u, p, t)
33+
[u[2], -9.8]
34+
end
35+
36+
function g!(resid, u_next, u, p, t)
37+
f = ff(u_next, p, t)
38+
resid[1] = u_next[1] - u[1] - 0.01*f[1]
39+
resid[2] = u_next[2] - u[2] - 0.01*f[2]
40+
nothing
41+
end
42+
u0 = [10., 0.]
43+
tspan = (0, 0.2)
44+
45+
idprob = ImplicitDiscreteProblem(g!, u0, tspan, []; dt = 0.01)
46+
idsol = solve(idprob, SimpleIDSolve())
47+
48+
oprob = ODEProblem(ff, u0, tspan)
49+
osol = solve(oprob, ImplicitEuler())
50+
51+
@test isapprox(idsol[end], osol[end], atol = 0.1)
52+
end
53+
54+
@testset "Solver initializes" begin
55+
function periodic!(resid, u_next, u, p, t)
56+
resid[1] = u_next[1] - u[1] - sin(t*π/4)
57+
resid[2] = 16 - u_next[2]^2 - u_next[1]^2
58+
end
59+
60+
tsteps = 15
61+
u0 = [1., 3.]
62+
idprob = ImplicitDiscreteProblem(periodic!, u0, (0, tsteps), [])
63+
integ = init(idprob, SimpleIDSolve())
64+
@test integ.u[1]^2 + integ.u[2]^2 16
65+
66+
for ts in 1:tsteps
67+
step!(integ)
68+
@show integ.u
69+
@test integ.u[1]^2 + integ.u[2]^2 16
70+
end
71+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
#Start Test Script
2727

2828
@time begin
29-
if contains(GROUP, "OrdinaryDiffEq")
29+
if contains(GROUP, "OrdinaryDiffEq") || GROUP == "ImplicitDiscreteSolve"
3030
Pkg.develop(path = "../lib/$GROUP")
3131
Pkg.test(GROUP)
3232
elseif GROUP == "All" || GROUP == "InterfaceI" || GROUP == "Interface"

0 commit comments

Comments
 (0)