Skip to content

Commit 262019c

Browse files
Update solution_interface.jl
1 parent 9c30360 commit 262019c

File tree

1 file changed

+121
-117
lines changed

1 file changed

+121
-117
lines changed

test/downstream/solution_interface.jl

Lines changed: 121 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,29 @@ using Plots: Plots, plot
66

77
### Tests on non-layered model (everything should work). ###
88

9-
@parameters a b c d
10-
@variables s1(t) s2(t)
11-
12-
eqs = [D(s1) ~ a * s1 / (1 + s1 + s2) - b * s1,
13-
D(s2) ~ +c * s2 / (1 + s1 + s2) - d * s2]
14-
15-
@mtkcompile population_model = System(eqs, t)
16-
17-
# Tests on ODEProblem.
18-
u0 = [s1 => 2.0, s2 => 1.0]
19-
p = [a => 2.0, b => 1.0, c => 1.0, d => 1.0]
20-
tspan = (0.0, 1000000.0)
21-
oprob = ODEProblem(population_model, [u0; p], tspan)
22-
sol = solve(oprob, Rodas4())
23-
24-
@test sol[s1] == sol[population_model.s1] == sol[:s1]
25-
@test sol[s2] == sol[population_model.s2] == sol[:s2]
26-
@test sol[s1][end] 1.0
27-
@test_throws Exception sol[a]
28-
@test_throws Exception sol[population_model.a]
29-
@test_throws Exception sol[:a]
9+
@testset "Basic indexing" begin
10+
@parameters a b c d
11+
@variables s1(t) s2(t)
12+
13+
eqs = [D(s1) ~ a * s1 / (1 + s1 + s2) - b * s1,
14+
D(s2) ~ +c * s2 / (1 + s1 + s2) - d * s2]
15+
16+
@mtkcompile population_model = System(eqs, t)
17+
18+
# Tests on ODEProblem.
19+
u0 = [s1 => 2.0, s2 => 1.0]
20+
p = [a => 2.0, b => 1.0, c => 1.0, d => 1.0]
21+
tspan = (0.0, 1000000.0)
22+
oprob = ODEProblem(population_model, [u0; p], tspan)
23+
sol = solve(oprob, Rodas4())
24+
25+
@test sol[s1] == sol[population_model.s1] == sol[:s1]
26+
@test sol[s2] == sol[population_model.s2] == sol[:s2]
27+
@test sol[s1][end] 1.0
28+
@test_throws Exception sol[a]
29+
@test_throws Exception sol[population_model.a]
30+
@test_throws Exception sol[:a]
31+
end
3032

3133
@testset "plot ODE solution" begin
3234
Plots.unicodeplots()
@@ -51,102 +53,104 @@ sol = solve(oprob, Rodas4())
5153
@test_nowarn plot(sol; plot_analytic = true)
5254
end
5355

54-
# Tests on SDEProblem
55-
noiseeqs = [0.1 * s1,
56-
0.1 * s2]
57-
@named noisy_population_model = SDESystem(population_model, noiseeqs)
58-
noisy_population_model = complete(noisy_population_model)
59-
sprob = SDEProblem(noisy_population_model, [u0; p], (0.0, 100.0))
60-
sol = solve(sprob, ImplicitEM())
61-
62-
@test sol[s1] == sol[noisy_population_model.s1] == sol[:s1]
63-
@test sol[s2] == sol[noisy_population_model.s2] == sol[:s2]
64-
@test_throws Exception sol[a]
65-
@test_throws Exception sol[noisy_population_model.a]
66-
@test_throws Exception sol[:a]
67-
@test_nowarn sol(0.5, idxs = noisy_population_model.s1)
68-
### Tests on layered model (some things should not work). ###
69-
70-
@parameters σ ρ β
71-
@variables x(t) y(t) z(t)
72-
73-
eqs = [D(x) ~ σ * (y - x),
74-
D(y) ~ x *- z) - y,
75-
D(z) ~ x * y - β * z]
76-
77-
@named lorenz1 = System(eqs, t)
78-
@named lorenz2 = System(eqs, t)
79-
80-
@parameters γ
81-
@variables a(t) α(t)
82-
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
83-
α ~ 2lorenz1.x + a * γ]
84-
@mtkcompile sys = System(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
85-
86-
u0 = [lorenz1.x => 1.0,
87-
lorenz1.y => 0.0,
88-
lorenz1.z => 0.0,
89-
lorenz2.x => 0.0,
90-
lorenz2.y => 1.0,
91-
lorenz2.z => 0.0]
92-
93-
p = [lorenz1.σ => 10.0,
94-
lorenz1.ρ => 28.0,
95-
lorenz1.β => 8 / 3,
96-
lorenz2.σ => 10.0,
97-
lorenz2.ρ => 28.0,
98-
lorenz2.β => 8 / 3,
99-
γ => 2.0]
100-
101-
tspan = (0.0, 100.0)
102-
prob = ODEProblem(sys, [u0; p], tspan)
103-
sol = solve(prob, Rodas4())
104-
105-
@test_throws ArgumentError sol[x]
106-
@test in(sol[lorenz1.x], [getindex.(sol.u, i) for i in 1:length(unknowns(sol.prob.f.sys))])
107-
@test_throws KeyError sol[:x]
108-
109-
### Non-symbolic indexing tests
110-
@test sol[:, 1] isa AbstractVector
111-
@test sol[:, 1:2] isa AbstractDiffEqArray
112-
@test sol[:, [1, 2]] isa AbstractDiffEqArray
113-
114-
sol1 = sol(0.0:1.0:10.0)
115-
@test sol1.u isa Vector
116-
@test first(sol1.u) isa Vector
117-
@test length(sol1.u) == 11
118-
@test length(sol1.t) == 11
119-
120-
sol2 = sol(0.1)
121-
@test sol2 isa Vector
122-
@test length(sol2) == length(unknowns(sys))
123-
@test first(sol2) isa Real
124-
125-
sol3 = sol(0.0:1.0:10.0, idxs = [lorenz1.x, lorenz2.x])
126-
127-
sol7 = sol(0.0:1.0:10.0, idxs = [2, 1])
128-
@test sol7.u isa Vector
129-
@test first(sol7.u) isa Vector
130-
@test length(sol7.u) == 11
131-
@test length(sol7.t) == 11
132-
@test collect(sol7[t]) sol3.t
133-
@test collect(sol7[t, 1:5]) sol3.t[1:5]
134-
135-
sol8 = sol(0.1, idxs = [2, 1])
136-
@test sol8 isa Vector
137-
@test length(sol8) == 2
138-
@test first(sol8) isa Real
139-
140-
sol9 = sol(0.0:1.0:10.0, idxs = 2)
141-
@test sol9.u isa Vector
142-
@test first(sol9.u) isa Real
143-
@test length(sol9.u) == 11
144-
@test length(sol9.t) == 11
145-
@test collect(sol9[t]) sol3.t
146-
@test collect(sol9[t, 1:5]) sol3.t[1:5]
147-
148-
sol10 = sol(0.1, idxs = 2)
149-
@test sol10 isa Real
56+
@testset "Symbolic Indexing" begin
57+
# Tests on SDEProblem
58+
noiseeqs = [0.1 * s1,
59+
0.1 * s2]
60+
@named noisy_population_model = SDESystem(population_model, noiseeqs)
61+
noisy_population_model = complete(noisy_population_model)
62+
sprob = SDEProblem(noisy_population_model, [u0; p], (0.0, 100.0))
63+
sol = solve(sprob, ImplicitEM())
64+
65+
@test sol[s1] == sol[noisy_population_model.s1] == sol[:s1]
66+
@test sol[s2] == sol[noisy_population_model.s2] == sol[:s2]
67+
@test_throws Exception sol[a]
68+
@test_throws Exception sol[noisy_population_model.a]
69+
@test_throws Exception sol[:a]
70+
@test_nowarn sol(0.5, idxs = noisy_population_model.s1)
71+
### Tests on layered model (some things should not work). ###
72+
73+
@parameters σ ρ β
74+
@variables x(t) y(t) z(t)
75+
76+
eqs = [D(x) ~ σ * (y - x),
77+
D(y) ~ x *- z) - y,
78+
D(z) ~ x * y - β * z]
79+
80+
@named lorenz1 = System(eqs, t)
81+
@named lorenz2 = System(eqs, t)
82+
83+
@parameters γ
84+
@variables a(t) α(t)
85+
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
86+
α ~ 2lorenz1.x + a * γ]
87+
@mtkcompile sys = System(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
88+
89+
u0 = [lorenz1.x => 1.0,
90+
lorenz1.y => 0.0,
91+
lorenz1.z => 0.0,
92+
lorenz2.x => 0.0,
93+
lorenz2.y => 1.0,
94+
lorenz2.z => 0.0]
95+
96+
p = [lorenz1.σ => 10.0,
97+
lorenz1.ρ => 28.0,
98+
lorenz1.β => 8 / 3,
99+
lorenz2.σ => 10.0,
100+
lorenz2.ρ => 28.0,
101+
lorenz2.β => 8 / 3,
102+
γ => 2.0]
103+
104+
tspan = (0.0, 100.0)
105+
prob = ODEProblem(sys, [u0; p], tspan)
106+
sol = solve(prob, Rodas4())
107+
108+
@test_throws ArgumentError sol[x]
109+
@test in(sol[lorenz1.x], [getindex.(sol.u, i) for i in 1:length(unknowns(sol.prob.f.sys))])
110+
@test_throws KeyError sol[:x]
111+
112+
### Non-symbolic indexing tests
113+
@test sol[:, 1] isa AbstractVector
114+
@test sol[:, 1:2] isa AbstractDiffEqArray
115+
@test sol[:, [1, 2]] isa AbstractDiffEqArray
116+
117+
sol1 = sol(0.0:1.0:10.0)
118+
@test sol1.u isa Vector
119+
@test first(sol1.u) isa Vector
120+
@test length(sol1.u) == 11
121+
@test length(sol1.t) == 11
122+
123+
sol2 = sol(0.1)
124+
@test sol2 isa Vector
125+
@test length(sol2) == length(unknowns(sys))
126+
@test first(sol2) isa Real
127+
128+
sol3 = sol(0.0:1.0:10.0, idxs = [lorenz1.x, lorenz2.x])
129+
130+
sol7 = sol(0.0:1.0:10.0, idxs = [2, 1])
131+
@test sol7.u isa Vector
132+
@test first(sol7.u) isa Vector
133+
@test length(sol7.u) == 11
134+
@test length(sol7.t) == 11
135+
@test collect(sol7[t]) sol3.t
136+
@test collect(sol7[t, 1:5]) sol3.t[1:5]
137+
138+
sol8 = sol(0.1, idxs = [2, 1])
139+
@test sol8 isa Vector
140+
@test length(sol8) == 2
141+
@test first(sol8) isa Real
142+
143+
sol9 = sol(0.0:1.0:10.0, idxs = 2)
144+
@test sol9.u isa Vector
145+
@test first(sol9.u) isa Real
146+
@test length(sol9.u) == 11
147+
@test length(sol9.t) == 11
148+
@test collect(sol9[t]) sol3.t
149+
@test collect(sol9[t, 1:5]) sol3.t[1:5]
150+
151+
sol10 = sol(0.1, idxs = 2)
152+
@test sol10 isa Real
153+
end
150154

151155
@testset "Plot idxs" begin
152156
@variables x(t) y(t)

0 commit comments

Comments
 (0)