Skip to content

Commit 4dc605e

Browse files
test: create JumpProblem directly
1 parent 7c69ba3 commit 4dc605e

File tree

2 files changed

+41
-41
lines changed

2 files changed

+41
-41
lines changed

test/jumpsystem.jl

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ p = (0.1 / 1000, 0.01);
6868
tspan = (0.0, 250.0);
6969
u₀map = [S => 999, I => 1, R => 0]
7070
parammap ==> 0.1 / 1000, γ => 0.01]
71-
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
72-
jprob = JumpProblem(js2, dprob, Direct(); save_positions = (false, false), rng)
71+
jprob = JumpProblem(js2, u₀map, tspan, parammap; aggregator = Direct(),
72+
save_positions = (false, false), rng)
73+
@test jprob.prob isa DiscreteProblem
7374
Nsims = 30000
7475
function getmean(jprob, Nsims; use_stepper = true)
7576
m = 0.0
@@ -82,21 +83,23 @@ end
8283
m = getmean(jprob, Nsims)
8384

8485
# test auto-alg selection works
85-
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)
86+
jprobb = JumpProblem(js2, u₀map, tspan, parammap; save_positions = (false, false), rng)
8687
mb = getmean(jprobb, Nsims; use_stepper = false)
8788
@test abs(m - mb) / m < 0.01
8889

8990
@variables S2(t)
9091
obs = [S2 ~ 2 * S]
9192
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
9293
js2b = complete(js2b)
93-
dprob = DiscreteProblem(js2b, u₀map, tspan, parammap)
94-
jprob = JumpProblem(js2b, dprob, Direct(); save_positions = (false, false), rng)
94+
jprob = JumpProblem(js2b, u₀map, tspan, parammap; aggregator = Direct(),
95+
save_positions = (false, false), rng)
96+
@test jprob.prob isa DiscreteProblem
9597
sol = solve(jprob, SSAStepper(); saveat = tspan[2] / 10)
9698
@test all(2 .* sol[S] .== sol[S2])
9799

98100
# test save_positions is working
99-
jprob = JumpProblem(js2, dprob, Direct(); save_positions = (false, false), rng)
101+
jprob = JumpProblem(js2, u₀map, tspan, parammap; aggregator = Direct(),
102+
save_positions = (false, false), rng)
100103
sol = solve(jprob, SSAStepper(); saveat = 1.0)
101104
@test all((sol.t) .== collect(0.0:tspan[2]))
102105

@@ -142,18 +145,20 @@ maj1 = MassActionJump(2 * β / 2, [S => 1, I => 1], [S => -1, I => 1])
142145
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
143146
@named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
144147
js3 = complete(js3)
145-
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
146-
jprob = JumpProblem(js3, dprob, Direct(); rng)
148+
jprob = JumpProblem(js3, u₀map, tspan, parammap; aggregator = Direct(), rng)
149+
@test jprob.prob isa DiscreteProblem
147150
m3 = getmean(jprob, Nsims)
148151
@test abs(m - m3) / m < 0.01
149152

150153
# maj jump test with various dep graphs
151154
@named js3b = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
152155
js3b = complete(js3b)
153-
jprobb = JumpProblem(js3b, dprob, NRM(); rng)
156+
jprobb = JumpProblem(js3b, u₀map, tspan, parammap; aggregator = NRM(), rng)
157+
@test jprobb.prob isa DiscreteProblem
154158
m4 = getmean(jprobb, Nsims)
155159
@test abs(m - m4) / m < 0.01
156-
jprobc = JumpProblem(js3b, dprob, RSSA(); rng)
160+
jprobc = JumpProblem(js3b, u₀map, tspan, parammap; aggregator = RSSA(), rng)
161+
@test jprobc.prob isa DiscreteProblem
157162
m4 = getmean(jprobc, Nsims)
158163
@test abs(m - m4) / m < 0.01
159164

@@ -162,8 +167,9 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
162167
maj2 = MassActionJump(γ, [S => 1], [S => -1])
163168
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
164169
js4 = complete(js4)
165-
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
166-
jprob = JumpProblem(js4, dprob, Direct(); rng)
170+
jprob = JumpProblem(
171+
js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01]; aggregator = Direct(), rng)
172+
@test jprob.prob isa DiscreteProblem
167173
m4 = getmean(jprob, Nsims)
168174
@test abs(m4 - 2.0 / 0.01) * 0.01 / 2.0 < 0.01
169175

@@ -172,8 +178,9 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
172178
maj2 = MassActionJump(γ, [S => 2], [S => -1])
173179
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
174180
js4 = complete(js4)
175-
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
176-
jprob = JumpProblem(js4, dprob, Direct(); rng)
181+
jprob = JumpProblem(
182+
js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01]; aggregator = Direct(), rng)
183+
@test jprob.prob isa DiscreteProblem
177184
sol = solve(jprob, SSAStepper());
178185

179186
# issue #819
@@ -195,8 +202,9 @@ let
195202
p = [k1 => 2.0, k2 => 0.0, k3 => 0.5]
196203
u₀ = [A => 100, B => 0]
197204
tspan = (0.0, 2000.0)
198-
dprob = DiscreteProblem(js5, u₀, tspan, p)
199-
jprob = JumpProblem(js5, dprob, Direct(); save_positions = (false, false), rng)
205+
jprob = JumpProblem(
206+
js5, u₀, tspan, p; aggregator = Direct(), save_positions = (false, false), rng)
207+
@test jprob.prob isa DiscreteProblem
200208
@test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0])
201209

202210
pcondit(u, t, integrator) = t == 1000.0
@@ -259,15 +267,10 @@ u0 = [X => 10]
259267
tspan = (0.0, 1.0)
260268
ps = [k => 1.0]
261269

262-
dp1 = DiscreteProblem(js1, u0, tspan, ps)
263-
dp2 = DiscreteProblem(js2, u0, tspan)
264-
dp3 = DiscreteProblem(js3, u0, tspan, ps)
265-
dp4 = DiscreteProblem(js4, u0, tspan)
266-
267-
@test_nowarn jp1 = JumpProblem(js1, dp1, Direct())
268-
@test_nowarn jp2 = JumpProblem(js2, dp2, Direct())
269-
@test_nowarn jp3 = JumpProblem(js3, dp3, Direct())
270-
@test_nowarn jp4 = JumpProblem(js4, dp4, Direct())
270+
@test_nowarn jp1 = JumpProblem(js1, u0, tspan, ps; aggregator = Direct())
271+
@test_nowarn jp2 = JumpProblem(js2, u0, tspan; aggregator = Direct())
272+
@test_nowarn jp3 = JumpProblem(js3, u0, tspan, ps; aggregator = Direct())
273+
@test_nowarn jp4 = JumpProblem(js4, u0, tspan; aggregator = Direct())
271274

272275
# Ensure `structural_simplify` (and `@mtkbuild`) works on JumpSystem (by doing nothing)
273276
# Issue#2558
@@ -292,8 +295,7 @@ let
292295
for (N, algtype) in zip(Nv, algtypes)
293296
@named jsys = JumpSystem([deepcopy(j1) for _ in 1:N], t, [X], [k])
294297
jsys = complete(jsys)
295-
dprob = DiscreteProblem(jsys, [X => 10], (0.0, 10.0), [k => 1])
296-
jprob = JumpProblem(jsys, dprob)
298+
jprob = JumpProblem(jsys, [X => 10], (0.0, 10.0), [k => 1])
297299
@test jprob.aggregator isa algtype
298300
end
299301
end
@@ -306,8 +308,9 @@ let
306308
@parameters k
307309
vrj = VariableRateJump(k * (sin(t) + 1), [A ~ A + 1, C ~ C + 2])
308310
js = complete(JumpSystem([vrj], t, [A, C], [k]; name = :js, observed = [B ~ C * A]))
309-
oprob = ODEProblem(js, [A => 0, C => 0], (0.0, 10.0), [k => 1.0])
310-
jprob = JumpProblem(js, oprob, Direct(); rng)
311+
jprob = JumpProblem(
312+
js, [A => 0, C => 0], (0.0, 10.0), [k => 1.0]; aggregtor = Direct(), rng)
313+
@test jprob.prob isa ODEProblem
311314
sol = solve(jprob, Tsit5())
312315

313316
# test observed and symbolic indexing work
@@ -439,8 +442,7 @@ let
439442
k2val = 20.0
440443
p = [k1 => k1val, k2 => k2val]
441444
tspan = (0.0, 10.0)
442-
oprob = ODEProblem(jsys, u0, tspan, p)
443-
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
445+
jprob = JumpProblem(jsys, u0, tspan, p; rng, save_positions = (false, false))
444446

445447
times = range(0.0, tspan[2], length = 100)
446448
Nsims = 4000
@@ -479,8 +481,7 @@ let
479481
u0map = [X => p.X₀, Y => p.Y₀]
480482
pmap ==> p.α, β => p.β]
481483
tspan = (0.0, 20.0)
482-
oprob = ODEProblem(jsys, u0map, tspan, pmap)
483-
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
484+
jprob = JumpProblem(jsys, u0, tspan, pmap; rng, save_positions = (false, false))
484485
times = range(0.0, tspan[2], length = 100)
485486
Nsims = 4000
486487
Xv = zeros(length(times))
@@ -518,8 +519,7 @@ let
518519
continuous_events = cevents)
519520
jsys = complete(jsys)
520521
tspan = (0.0, 200.0)
521-
oprob = ODEProblem(jsys, u0map, tspan, pmap)
522-
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
522+
jprob = JumpProblem(jsys, u0, tspan, pmap; rng, save_positions = (false, false))
523523
Xsamp = 0.0
524524
Nsims = 4000
525525
for n in 1:Nsims
@@ -544,8 +544,8 @@ end
544544

545545
# Works.
546546
@mtkbuild js = JumpSystem([j1, j2], t, [X], [p, d])
547-
dprob = DiscreteProblem(js, [X => 15], (0.0, 10.0), [p => 2.0, d => 0.5])
548-
jprob = JumpProblem(js, dprob, Direct())
547+
jprob = JumpProblem(
548+
js, [X => 15], (0.0, 10.0), [p => 2.0, d => 0.5]; aggregator = Direct())
549549
sol = solve(jprob, SSAStepper())
550550
@test eltype(sol[X]) === Int64
551551
end

test/parameter_dependencies.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ end
300300
tspan = (0.0, 250.0)
301301
u₀map = [S => 999, I => 1, R => 0]
302302
parammap ==> 0.01]
303-
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
304-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
303+
jprob = JumpProblem(js2, u₀map, tspan, parammap; aggregator = Direct(),
304+
save_positions = (false, false), rng = rng)
305305
@test jprob.ps[γ] == 0.01
306306
@test jprob.ps[β] == 0.0001
307307
@test_nowarn solve(jprob, SSAStepper())
@@ -310,8 +310,8 @@ end
310310
[j₁, j₃], t, [S, I, R], [γ]; parameter_dependencies ==> 0.01γ],
311311
discrete_events = [[10.0] =>~ 0.02]])
312312
js2 = complete(js2)
313-
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
314-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
313+
jprob = JumpProblem(js2, u₀map, tspan, parammap; aggregator = Direct(),
314+
save_positions = (false, false), rng = rng)
315315
integ = init(jprob, SSAStepper())
316316
@test integ.ps[γ] == 0.01
317317
@test integ.ps[β] == 0.0001

0 commit comments

Comments
 (0)