Skip to content

Commit 676403e

Browse files
test: test new initialization features
1 parent c9cb8f8 commit 676403e

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed

test/initialization.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
2+
3+
@testset "CheckInit" begin
4+
@testset "ODEProblem" begin
5+
function rhs(u, p, t)
6+
return [u[1] * t, u[1]^2 - u[2]^2]
7+
end
8+
function rhs!(du, u, p, t)
9+
du[1] = u[1] * t
10+
du[2] = u[1]^2 - u[2]^2
11+
end
12+
13+
oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
14+
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])
15+
16+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
17+
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
18+
integ = init(prob)
19+
u0, _, success = SciMLBase.get_initial_values(
20+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
21+
@test success
22+
@test u0 == prob.u0
23+
24+
integ.u[2] = 2.0
25+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
26+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
27+
end
28+
end
29+
30+
@testset "DAEProblem" begin
31+
function daerhs(du, u, p, t)
32+
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
33+
end
34+
function daerhs!(resid, du, u, p, t)
35+
resid[1] = du[1] - u[1] * t - p
36+
resid[2] = u[1]^2 - u[2]^2
37+
end
38+
39+
oopfn = DAEFunction{false}(daerhs)
40+
iipfn = DAEFunction{true}(daerhs!)
41+
42+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
43+
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
44+
integ = init(prob, DImplicitEuler())
45+
u0, _, success = SciMLBase.get_initial_values(
46+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
47+
@test success
48+
@test u0 == prob.u0
49+
50+
integ.u[2] = 2.0
51+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
52+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
53+
54+
integ.u[2] = 1.0
55+
integ.du[1] = 2.0
56+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
57+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
58+
end
59+
end
60+
end
61+
62+
@testset "OverrideInit" begin
63+
function rhs2(u, p, t)
64+
return [u[1] * t + p, u[1]^2 - u[2]^2]
65+
end
66+
67+
@testset "No-op without `initialization_data`" begin
68+
prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0)
69+
integ = init(prob)
70+
integ.u[2] = 3.0
71+
u0, p, success = SciMLBase.get_initial_values(
72+
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false))
73+
@test u0 [1.0, 3.0]
74+
@test success
75+
end
76+
77+
# unknowns are u[2], p. Parameter is u[1]
78+
initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1
79+
u2, p = x
80+
u1 = _u1[1]
81+
return [u1^2 - u2^2, p^2 - 2p + 1]
82+
end
83+
update_initializeprob! = function (iprob, integ)
84+
iprob.p[1] = integ.u[1]
85+
end
86+
initprobmap = function (nlsol)
87+
return [parameter_values(nlsol)[1], nlsol.u[1]]
88+
end
89+
initprobpmap = function (nlsol)
90+
return nlsol.u[2]
91+
end
92+
initialization_data = SciMLBase.OverrideInitData(
93+
initprob, update_initializeprob!, initprobmap, initprobpmap)
94+
fn = ODEFunction(rhs2; initialization_data)
95+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
96+
integ = init(prob; initializealg = NoInit())
97+
98+
@testset "Errors without `nlsolve_alg`" begin
99+
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
100+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
101+
end
102+
103+
@testset "Solves" begin
104+
u0, p, success = SciMLBase.get_initial_values(
105+
prob, integ, fn, SciMLBase.OverrideInit(),
106+
Val(false); nlsolve_alg = NewtonRaphson())
107+
108+
@test u0 [2.0, 2.0]
109+
@test p 1.0
110+
@test success
111+
112+
initprob.p[1] = 1.0
113+
end
114+
115+
@testset "Solves with non-integrator value provider" begin
116+
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
117+
u0, p, success = SciMLBase.get_initial_values(
118+
prob, _integ, fn, SciMLBase.OverrideInit(),
119+
Val(false); nlsolve_alg = NewtonRaphson())
120+
121+
@test u0 [2.0, 2.0]
122+
@test p 1.0
123+
@test success
124+
125+
initprob.p[1] = 1.0
126+
end
127+
128+
@testset "Solves without `update_initializeprob!`" begin
129+
initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing
130+
fn = ODEFunction(rhs2; initialization_data = initdata)
131+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
132+
integ = init(prob; initializealg = NoInit())
133+
134+
u0, p, success = SciMLBase.get_initial_values(
135+
prob, integ, fn, SciMLBase.OverrideInit(),
136+
Val(false); nlsolve_alg = NewtonRaphson())
137+
@test u0 [1.0, 1.0]
138+
@test p 1.0
139+
@test success
140+
end
141+
142+
@testset "Solves without `initializeprobpmap`" begin
143+
initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing
144+
fn = ODEFunction(rhs2; initialization_data = initdata)
145+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
146+
integ = init(prob; initializealg = NoInit())
147+
148+
u0, p, success = SciMLBase.get_initial_values(
149+
prob, integ, fn, SciMLBase.OverrideInit(),
150+
Val(false); nlsolve_alg = NewtonRaphson())
151+
152+
@test u0 [2.0, 2.0]
153+
@test p 0.0
154+
@test success
155+
end
156+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ end
6363
@time @safetestset "Serialization tests" begin
6464
include("serialization_tests.jl")
6565
end
66+
@time @safetestset "Initialization" begin
67+
include("initialization.jl")
68+
end
6669
end
6770

6871
if !is_APPVEYOR &&

0 commit comments

Comments
 (0)