Skip to content

Commit ec56a82

Browse files
committed
Add initialization
1 parent b3bd14f commit ec56a82

File tree

4 files changed

+74
-21
lines changed

4 files changed

+74
-21
lines changed

lib/SimpleImplicitDiscreteSolve/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ Reexport = "1.2.2"
2020
SciMLBase = "2.74.1"
2121
SimpleNonlinearSolve = "2.1.0"
2222
SymbolicIndexingInterface = "0.3.38"
23+
Test = "1.11.0"
2324
UnPack = "1.0.2"
2425

2526
[extras]
2627
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
27-
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
28+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]
30-
test = ["OrdinaryDiffEqSDIRK", "SimpleNonlinearSolve"]
31+
test = ["OrdinaryDiffEqSDIRK", "Test"]

lib/SimpleImplicitDiscreteSolve/src/alg_utils.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,3 @@ beta1_default(alg::SimpleIDSolve, beta2) = 0
1717

1818
dt_required(alg::SimpleIDSolve) = false
1919
isdiscretealg(alg::SimpleIDSolve) = true
20-
21-
function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
22-
alg::DefaultInit, x::Union{Val{true}, Val{false}})
23-
atol = one(eltype(prob.u0)) * 1e-12
24-
if SciMLBase.has_initializeprob(prob.f)
25-
_initialize_dae!(integrator, prob,
26-
OverrideInit(atol), x)
27-
elseif !applicable(_initialize_dae!, integrator, prob,
28-
BrownFullBasicInit(atol), x)
29-
error("`OrdinaryDiffEqNonlinearSolve` is not loaded, which is required for the default initialization algorithm (`BrownFullBasicInit` or `ShampineCollocationInit`). To solve this problem, either do `using OrdinaryDiffEqNonlinearSolve` or pass `initializealg = CheckInit()` to the `solve` function. This second option requires consistent `u0`.")
30-
else
31-
_initialize_dae!(integrator, prob,
32-
BrownFullBasicInit(atol), x)
33-
end
34-
end

lib/SimpleImplicitDiscreteSolve/src/solve.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,52 @@ function initialize!(integrator, cache::SimpleIDSolveCache)
3030
end
3131
cache.prob = prob
3232
end
33+
34+
function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
35+
alg::DefaultInit, x::Union{Val{true}, Val{false}})
36+
atol = one(eltype(prob.u0)) * 1e-12
37+
if SciMLBase.has_initializeprob(prob.f)
38+
_initialize_dae!(integrator, prob,
39+
OverrideInit(atol), x)
40+
else
41+
@unpack u, p, t, f = integrator
42+
initstate = ImplicitDiscreteState(u, p, t)
43+
44+
_f = if isinplace(f)
45+
(resid, u_next, p) -> f(resid, u_next, p.u, p.p, p.t_next)
46+
else
47+
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
48+
end
49+
prob = NonlinearProblem{isinplace(f)}(_f, u, initstate)
50+
sol = solve(prob, SimpleNewtonRaphson())
51+
integrator.u = sol
52+
end
53+
end
54+
55+
#### TODO: Implement real algorithm
56+
function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem, alg::BrownFullBasicInit, isinplace::Val{true})
57+
@unpack p, t, f = integrator
58+
u0 = integrator.u
59+
60+
nlequation! = (out, u, p) -> begin
61+
f(out, u, u0, p, t)
62+
end
63+
64+
nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
65+
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u), p)
66+
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
67+
68+
@. du = ifelse(differential_vars, nlsol.u, du)
69+
@. u = ifelse(differential_vars, u, nlsol.u)
70+
71+
recursivecopy!(integrator.uprev, integrator.u)
72+
if alg_extrapolates(integrator.alg)
73+
recursivecopy!(integrator.uprev2, integrator.uprev)
74+
end
75+
76+
if nlsol.retcode != ReturnCode.Success
77+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
78+
ReturnCode.InitialFailure)
79+
end
80+
return
81+
end

lib/SimpleImplicitDiscreteSolve/test/runtests.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#runtests
2+
using Test
23
using SimpleImplicitDiscreteSolve
34
using OrdinaryDiffEqCore
45
using OrdinaryDiffEqSDIRK
@@ -24,7 +25,7 @@ using OrdinaryDiffEqSDIRK
2425
oprob = ODEProblem(lotkavolterra, u0, tspan)
2526
osol = solve(oprob, ImplicitEuler())
2627

27-
@test isapprox(idsol[end], osol[end], atol = 0.01)
28+
@test isapprox(idsol[end], osol[end], atol = 0.1)
2829

2930
### free-fall
3031
# y, dy
@@ -38,16 +39,33 @@ using OrdinaryDiffEqSDIRK
3839
resid[2] = u_next[2] - u[2] - 0.01*f[2]
3940
nothing
4041
end
41-
u0 = [100., 3.]
42+
u0 = [10., 0.]
43+
tspan = (0, 0.2)
4244

4345
idprob = ImplicitDiscreteProblem(g!, u0, tspan, []; dt = 0.01)
4446
idsol = solve(idprob, SimpleIDSolve())
4547

4648
oprob = ODEProblem(ff, u0, tspan)
4749
osol = solve(oprob, ImplicitEuler())
4850

49-
@test isapprox(idsol[end], osol[end], atol = 0.01)
51+
@test isapprox(idsol[end], osol[end], atol = 0.1)
5052
end
5153

52-
@testset "Solve respects initialization" begin
54+
@testset "Solver initializes" begin
55+
function periodic!(resid, u_next, u, p, t)
56+
resid[1] = u_next[1] - u[1] - sin(t*π/4)
57+
resid[2] = 16 - u_next[2]^2 - u_next[1]^2
58+
end
59+
60+
tsteps = 15
61+
u0 = [1., 3.]
62+
idprob = ImplicitDiscreteProblem(periodic!, u0, (0, tsteps), [])
63+
integ = init(idprob, SimpleIDSolve())
64+
@test integ.u[1]^2 + integ.u[2]^2 16
65+
66+
for ts in 1:tsteps
67+
step!(integ)
68+
@show integ.u
69+
@test integ.u[1]^2 + integ.u[2]^2 16
70+
end
5371
end

0 commit comments

Comments
 (0)