Skip to content

Commit 7c83d1c

Browse files
committed
begin implementing NLSolve
1 parent 827e64e commit 7c83d1c

File tree

5 files changed

+149
-14
lines changed

5 files changed

+149
-14
lines changed

lib/ImplicitDiscreteSolve/Project.toml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,37 @@ name = "ImplicitDiscreteSolve"
22
uuid = "3263718b-31ed-49cf-8a0f-35a466e8af96"
33
authors = ["vyudu <[email protected]>"]
44
version = "0.1.0"
5+
6+
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
9+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
10+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
13+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
14+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
15+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
16+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
17+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
18+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
19+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
20+
21+
[compat]
22+
ADTypes = "1.13.0"
23+
CommonSolve = "0.2.4"
24+
DiffEqBase = "6.164.1"
25+
DifferentiationInterface = "0.6.42"
26+
LinearAlgebra = "1.11.0"
27+
NonlinearSolveBase = "1.5.0"
28+
OrdinaryDiffEqCore = "1.18.1"
29+
OrdinaryDiffEqSDIRK = "1.2.0"
30+
Reexport = "1.2.2"
31+
SciMLBase = "2.74.1"
32+
SimpleNonlinearSolve = "2.1.0"
33+
SymbolicIndexingInterface = "0.3.37"
34+
UnPack = "1.0.2"
35+
36+
[extras]
37+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
38+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"

lib/ImplicitDiscreteSolve/src/ImplicitDiscreteSolve.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@ using NonlinearSolveBase
66
using SymbolicIndexingInterface
77
using LinearAlgebra
88
using ADTypes
9-
using TaylorDiff
10-
using DocStringExtensions
9+
using UnPack
10+
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, get_fsalfirstlast, initialize!, perform_step!
1111
import CommonSolve
1212
import DifferentiationInterface as DI
1313

14-
using ConcreteStructs: @concrete
14+
using Reexport
15+
@reexport using DiffEqBase
1516

1617
"""
17-
IteratedNonlinearSolve(; nlsolvealg, autodiff = true, kwargs...)
18+
IDSolve(alg; autodiff = true, kwargs...)
1819
19-
This algorithm is a solver for ImplicitDiscreteSystems.
20+
Solver for `ImplicitDiscreteSystems`. `alg` is the NonlinearSolve algorithm that is used to solve for the next timestep at each step.
2021
"""
21-
@concrete struct IteratedNonlinearSolve <: NonlinearSolveBase.AbstractNonlinearSolveAlgorithm
22-
nlsolvealg
23-
autodiff
24-
kwargs
22+
struct IDSolve{algType} <: OrdinaryDiffEqAlgorithm
23+
nlsolve::algType
2524
end
2625

27-
export IteratedNonlinearSolve
28-
26+
include("cache.jl")
2927
include("solve.jl")
3028

29+
export IDSolve
30+
3131
end
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
mutable struct ImplicitDiscreteState{uType, pType, tType}
2+
u::Vector{uType}
3+
p::Union{Nothing, Vector{pType}}
4+
t_next::tType
5+
end
6+
7+
mutable struct IDSolveCache{uType} <: OrdinaryDiffEqMutableCache
8+
u::uType
9+
uprev::uType
10+
state::ImplicitDiscreteState
11+
prob::Union{Nothing, AbstractNonlinearProblem}
12+
end
13+
14+
function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
15+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
16+
dt, reltol, p, calck,
17+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
18+
19+
state = ImplicitDiscreteState(similar(u), similar(p), t)
20+
IDSolveCache(u, uprev, state, nothing)
21+
end
22+
23+
isdiscretecache(cache::IDSolveCache) = true
24+
25+
function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
26+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
27+
dt, reltol, p, calck,
28+
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
29+
30+
state = ImplicitDiscreteState(similar(u), similar(p), t)
31+
IDSolveCache(u, uprev, state, nothing)
32+
end
33+
34+
isfsal(alg::IDSolve) = false
35+
get_fsalfirstlast(cache::IDSolveCache, rate_prototype) = (nothing, nothing)
Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,46 @@
1+
# Remake the nonlinear problem, then update
2+
function perform_step!(integrator, cache::IDSolveCache, repeat_step = false)
3+
@unpack alg, u, uprev, dt, t, f, p = integrator
4+
nlsolve = alg.nlsolve
5+
@unpack state, prob = cache
6+
state.u .= uprev
7+
state.t_next = t
8+
@show state
9+
prob = remake(prob, p = state)
110

2-
# make a nonlinear problem and solve at every timestep.
11+
u = solve(prob, nlsolve)
12+
any(isnan, u) && (integrator.sol.retcode = SciMLBase.ReturnCode.Failure)
13+
integrator.u = u
14+
end
15+
16+
function initialize!(integrator, cache::IDSolveCache)
17+
cache.state.u .= integrator.u
18+
cache.state.p .= integrator.p
19+
cache.state.t_next = integrator.t
20+
f = integrator.f
321

4-
function CommonSolve.solve(prob::ImplicitDiscreteProblem)
5-
22+
_f = if isinplace(f)
23+
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
24+
else
25+
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
26+
end
27+
28+
prob = if isinplace(f)
29+
NonlinearProblem{true}(_f, cache.state.u, cache.state)
30+
else
31+
NonlinearProblem{false}(_f, cache.state.u, cache.state)
32+
end
33+
cache.prob = prob
634
end
35+
36+
#### unnecessary
37+
#function DiffEqBase.__init(prob::ImplicitDiscreteProblem, alg)
38+
# f = prob.f
39+
# t_i = prob.tspan[1]
40+
# u0 = state_values(prob)
41+
# p = parameter_values(prob)
42+
#
43+
# _f(resid, u_next, p) = f(resid, u_next, p.u, p.p, p.t)
44+
# state = ImplicitDiscreteState(u0, p, t_i)
45+
# nlprob = NonlinearProblem(_f, u0, state)
46+
#end
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#runtests
2+
using ImplicitDiscreteSolve
3+
using OrdinaryDiffEqCore
4+
using OrdinaryDiffEqSDIRK
5+
using SimpleNonlinearSolve
6+
7+
# Test implicit Euler using ImplicitDiscreteProblem
8+
@testset "Implicit Discrete System" begin
9+
function lotkavolterra(u, p, t)
10+
[1.5*u[1] - u[1]*u[2], -3.0*u[2] + u[1]*u[2]]
11+
end
12+
13+
function f!(resid, u_next, u, p, t)
14+
@. resid = u_next - 0.01*lotkavolterra(u_next, p, t)
15+
end
16+
u0 = [1., 1.]
17+
tspan = (0., 1.)
18+
19+
idprob = ImplicitDiscreteProblem(f!, u0, tspan, []; dt = 0.01)
20+
idsol = solve(idprob, IDSolve(SimpleNewtonRaphson()))
21+
22+
oprob = ODEProblem(lotkavolterra, u0, tspan)
23+
osol = solve(oprob, ImplicitEuler())
24+
25+
@test idsol[end] osol[end]
26+
end

0 commit comments

Comments
 (0)