Skip to content

Commit 06aed2e

Browse files
authored
Merge pull request #1198 from SciML/add_simulation_type_preservation_tests
Add simulation type preservation tests
2 parents 0013818 + a54094c commit 06aed2e

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

test/simulation_and_solving/simulate_ODEs.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,39 @@ end
164164

165165
### Other Tests ###
166166

167+
# Checks that solution values have types consistent with their input types.
168+
# Check that both float types are preserved in the solution (and problems), while integers are
169+
# promoted to floats.
170+
# Checks that the time types are correct (`Float64` by default or possibly `Float32`).
171+
let
172+
# Create model. Checks when input type is `Float64` the produced values are also `Float64`.
173+
rn = @reaction_network begin
174+
(k1,k2), X1 <--> X2
175+
end
176+
u0 = [:X1 => 1.0, :X2 => 3.0]
177+
ps = [:k1 => 2.0, :k2 => 3.0]
178+
oprob = ODEProblem(rn, u0, 1.0, ps)
179+
osol = solve(oprob)
180+
@test eltype(osol[:X1]) == eltype(osol[:X2]) == typeof(oprob[:X1]) == typeof(oprob[:X2]) == Float64
181+
@test eltype(osol.t) == typeof(oprob.tspan[1]) == typeof(oprob.tspan[2]) == Float64
182+
183+
# Checks that `Int64` values are promoted to `Float64`.
184+
u0 = [:X1 => 1, :X2 => 3]
185+
ps = [:k1 => 2, :k2 => 3]
186+
oprob = ODEProblem(rn, u0, 1, ps)
187+
osol = solve(oprob)
188+
@test eltype(osol[:X1]) == eltype(osol[:X2]) == typeof(oprob[:X1]) == typeof(oprob[:X2]) == Float64
189+
@test eltype(osol.t) == Float64
190+
191+
# Checks when values are `Float32` (a valid type and should be preserved).
192+
u0 = [:X1 => 1.0f0, :X2 => 3.0f0]
193+
ps = [:k1 => 2.0f0, :k2 => 3.0f0]
194+
oprob = ODEProblem(rn, u0, 1.0f0, ps)
195+
osol = solve(oprob)
196+
@test eltype(osol[:X1]) == eltype(osol[:X2]) == typeof(oprob[:X1]) == typeof(oprob[:X2]) == Float32
197+
@test eltype(osol.t) == typeof(oprob.tspan[1]) == typeof(oprob.tspan[2]) == Float32
198+
end
199+
167200
# Tests simulating a network without parameters.
168201
let
169202
no_param_network = @reaction_network begin

test/simulation_and_solving/simulate_SDEs.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,40 @@ end
383383

384384
### Other Tests ###
385385

386+
# Checks that solution values have types consistent with their input types.
387+
# Check that both float types are preserved in the solution (and problems), while integers are
388+
# promoted to floats.
389+
# Checks that the time types are correct (`Float64` by default or possibly `Float32`), however,
390+
# type conversion only occurs in the solution, and integer types are preserved in problems.
391+
let
392+
# Create model. Checks when input type is `Float64` the produced values are also `Float64`.
393+
rn = @reaction_network begin
394+
(k1,k2), X1 <--> X2
395+
end
396+
u0 = [:X1 => 1.0, :X2 => 3.0]
397+
ps = [:k1 => 2.0, :k2 => 3.0]
398+
sprob = SDEProblem(rn, u0, 1.0, ps)
399+
ssol = solve(sprob, ISSEM())
400+
@test eltype(ssol[:X1]) == eltype(ssol[:X2]) == typeof(sprob[:X1]) == typeof(sprob[:X2]) == Float64
401+
@test eltype(ssol.t) == typeof(sprob.tspan[1]) == typeof(sprob.tspan[2]) == Float64
402+
403+
# Checks that `Int64` values are promoted to `Float64`.
404+
u0 = [:X1 => 1, :X2 => 3]
405+
ps = [:k1 => 2, :k2 => 3]
406+
sprob = SDEProblem(rn, u0, 1, ps)
407+
ssol = solve(sprob, ISSEM())
408+
@test eltype(ssol[:X1]) == eltype(ssol[:X2]) == typeof(sprob[:X1]) == typeof(sprob[:X2]) == Float64
409+
@test eltype(ssol.t) == Float64
410+
411+
# Checks when values are `Float32` (a valid type and should be preserved).
412+
u0 = [:X1 => 1.0f0, :X2 => 3.0f0]
413+
ps = [:k1 => 2.0f0, :k2 => 3.0f0]
414+
sprob = SDEProblem(rn, u0, 1.0f0, ps)
415+
ssol = solve(sprob, ISSEM())
416+
@test eltype(ssol[:X1]) == eltype(ssol[:X2]) == typeof(sprob[:X1]) == typeof(sprob[:X2]) == Float32
417+
@test eltype(ssol.t) == typeof(sprob.tspan[1]) == typeof(sprob.tspan[2]) == Float32
418+
end
419+
386420
# Tests simulating a network without parameters.
387421
let
388422
no_param_network = @reaction_network begin

test/simulation_and_solving/simulate_jumps.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,47 @@ let
220220
@test (means1[1] - means1[2]) < .1 * means1[1]
221221
@test (means2[1] - means2[2]) < .1 * means2[1]
222222
end
223+
224+
### Other Tests ###
225+
226+
# Checks that solution values have types consistent with their input types.
227+
# Check that both float and integer types are preserved in the solution (and problems).
228+
# Checks that the time types are correct (`Float64` by default or possibly `Float32`).
229+
# `JumpInputs` currently does not support integer time spans. When it does, we will check that
230+
# these produce `Float64` time values.
231+
let
232+
# Create model. Checks when input type is `Float64` the produced values are also `Float64`.
233+
rn = @reaction_network begin
234+
(k1,k2), X1 <--> X2
235+
end
236+
u0 = [:X1 => 1.0, :X2 => 3.0]
237+
ps = [:k1 => 2.0, :k2 => 3.0]
238+
jprob = JumpProblem(JumpInputs(rn, u0, (0.0, 1.0), ps))
239+
jsol = solve(jprob)
240+
@test eltype(jsol[:X1]) == eltype(jsol[:X2]) == typeof(jprob[:X1]) == typeof(jprob[:X2]) == Float64
241+
@test eltype(jsol.t) == typeof(jprob.prob.tspan[1]) == typeof(jprob.prob.tspan[2]) == Float64
242+
243+
# Checks that `Int64` gives `Int64` species values.
244+
u0 = [:X1 => 1 :X2 => 3]
245+
ps = [:k1 => 2, :k2 => 3]
246+
jprob = JumpProblem(JumpInputs(rn, u0, (0.0, 1.0), ps))
247+
jsol = solve(jprob)
248+
@test eltype(jsol[:X1]) == eltype(jsol[:X2]) == typeof(jprob[:X1]) == typeof(jprob[:X2]) == Int64
249+
@test eltype(jsol.t) == typeof(jprob.prob.tspan[1]) == typeof(jprob.prob.tspan[2]) == Float64
250+
251+
# Checks when values are `Float32` (a valid type and should be preserved).
252+
u0 = [:X1 => 1.0f0, :X2 => 3.0f0]
253+
ps = [:k1 => 2.0f0, :k2 => 3.0f0]
254+
jprob = JumpProblem(JumpInputs(rn, u0, (0.0f0, 1.0f0), ps))
255+
jsol = solve(jprob)
256+
@test eltype(jsol[:X1]) == eltype(jsol[:X2]) == typeof(jprob[:X1]) == typeof(jprob[:X2]) == Float32
257+
@test eltype(jsol.t) == typeof(jprob.prob.tspan[1]) == typeof(jprob.prob.tspan[2]) == Float32
258+
259+
# Checks when values are `Int32` (a valid species type and should be preserved).
260+
u0 = [:X1 => Int32(1), :X2 => Int32(3)]
261+
ps = [:k1 => Int32(2), :k2 => Int32(3)]
262+
jprob = JumpProblem(JumpInputs(rn, u0, (0.0, 1.0), ps))
263+
jsol = solve(jprob)
264+
@test eltype(jsol[:X1]) == eltype(jsol[:X2]) == typeof(jprob[:X1]) == typeof(jprob[:X2]) == Int32
265+
@test eltype(jsol.t) == typeof(jprob.prob.tspan[1]) == typeof(jprob.prob.tspan[2]) == Float64
266+
end

test/simulation_and_solving/solve_nonlinear.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,34 @@ end
101101
@test f_eval(steady_state_network_3, [:X => sol1[X], :Y => sol1[Y], :Y2 => sol1[Y2], :XY2 => sol1[XY2]], p, 0.0) [0.0, 0.0, 0.0, 0.0] atol=1e-10
102102
@test f_eval(steady_state_network_3, [:X => sol2[X], :Y => sol2[Y], :Y2 => sol2[Y2], :XY2 => sol2[XY2]], p, 0.0) [0.0, 0.0, 0.0, 0.0] atol=1e-10
103103
end
104+
105+
### Other Tests ###
106+
107+
# Checks that solution values have types consistent with their input types.
108+
# Check for values that types that should be preserved (`Float64` and `Float32`) and types
109+
# that should be converted to the default (conversion of `Int64 to `Float64`).
110+
let
111+
# Create model. Checks when input type is `Float64` that the problem and solution types are `Float64`.
112+
rn = @reaction_network begin
113+
(k1,k2), X1 <--> X2
114+
end
115+
u0 = [:X1 => 1.0, :X2 => 3.0]
116+
ps = [:k1 => 2.0, :k2 => 3.0]
117+
nlprob = NonlinearProblem(rn, u0, ps)
118+
nlsol = solve(nlprob)
119+
@test eltype(nlsol[:X1]) == eltype(nlsol[:X2]) == typeof(nlprob[:X1]) == typeof(nlprob[:X2]) == Float64
120+
121+
# Checks that input type `Int64` is converted to `Float64`.
122+
u0 = [:X1 => 1, :X2 => 3]
123+
ps = [:k1 => 2, :k2 => 3]
124+
nlprob = NonlinearProblem(rn, u0, ps)
125+
nlsol = solve(nlprob)
126+
@test eltype(nlsol[:X1]) == eltype(nlsol[:X2]) == typeof(nlprob[:X1]) == typeof(nlprob[:X2]) == Float64
127+
128+
# Checks that input type `Float32` is preserved
129+
u0 = [:X1 => 1.0f0, :X2 => 3.0f0]
130+
ps = [:k1 => 2.0f0, :k2 => 3.0f0]
131+
nlprob = NonlinearProblem(rn, u0, ps)
132+
nlsol = solve(nlprob)
133+
@test eltype(nlsol[:X1]) == eltype(nlsol[:X2]) == typeof(nlprob[:X1]) == typeof(nlprob[:X2]) == Float32
134+
end

0 commit comments

Comments
 (0)