Skip to content

Commit 6c2ef68

Browse files
committed
ODEProblem for vrjs and add back test
1 parent c10d00b commit 6c2ef68

File tree

2 files changed

+88
-32
lines changed

2 files changed

+88
-32
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ end
311311
```julia
312312
DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
313313
parammap = DiffEqBase.NullParameters;
314-
use_union = false,
314+
use_union = true,
315315
kwargs...)
316316
```
317317
@@ -331,7 +331,6 @@ dprob = DiscreteProblem(complete(js), u₀map, tspan, parammap)
331331
"""
332332
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
333333
parammap = DiffEqBase.NullParameters();
334-
checkbounds = false,
335334
use_union = true,
336335
eval_expression = false,
337336
eval_module = @__MODULE__,
@@ -385,7 +384,7 @@ struct DiscreteProblemExpr{iip} end
385384

386385
function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
387386
parammap = DiffEqBase.NullParameters();
388-
use_union = false,
387+
use_union = true,
389388
kwargs...) where {iip}
390389
if !iscomplete(sys)
391390
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
@@ -412,6 +411,61 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
412411
end
413412
end
414413

414+
"""
415+
```julia
416+
DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan,
417+
parammap = DiffEqBase.NullParameters;
418+
use_union = true,
419+
kwargs...)
420+
```
421+
422+
Generates a blank ODEProblem for a pure jump JumpSystem to utilize as its `prob.prob`. This
423+
is used in the case where there are no ODEs and no SDEs associated with the system but there
424+
are jumps with an explicit time dependency (i.e. `VariableRateJump`s). If no jumps have an
425+
explicit time dependence, i.e. all are `ConstantRateJump`s or `MassActionJump`s then
426+
`DiscreteProblem` should be preferred for performance reasons.
427+
428+
Continuing the example from the [`JumpSystem`](@ref) definition:
429+
430+
```julia
431+
using DiffEqBase, JumpProcesses
432+
u₀map = [S => 999, I => 1, R => 0]
433+
parammap = [β => 0.1 / 1000, γ => 0.01]
434+
tspan = (0.0, 250.0)
435+
oprob = ODEProblem(complete(js), u₀map, tspan, parammap)
436+
```
437+
"""
438+
function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
439+
parammap = DiffEqBase.NullParameters();
440+
use_union = true,
441+
eval_expression = false,
442+
eval_module = @__MODULE__,
443+
kwargs...)
444+
if !iscomplete(sys)
445+
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
446+
end
447+
dvs = unknowns(sys)
448+
ps = parameters(sys)
449+
450+
defs = defaults(sys)
451+
defs = mergedefaults(defs, parammap, ps)
452+
defs = mergedefaults(defs, u0map, dvs)
453+
454+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
455+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
456+
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
457+
else
458+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
459+
end
460+
461+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
462+
463+
f = (du, u, p, t) -> (du .= 0; nothing)
464+
df = ODEFunction(f; sys = sys, observed = observedfun)
465+
ODEProblem(df, u0, tspan, p; kwargs...)
466+
end
467+
468+
415469
"""
416470
```julia
417471
DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)

test/jumpsystem.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ tspan = (0.0, 250.0);
6767
u₀map = [S => 999, I => 1, R => 0]
6868
parammap ==> 0.1 / 1000, γ => 0.01]
6969
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
70-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
70+
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng)
7171
Nsims = 30000
7272
function getmean(jprob, Nsims; use_stepper = true)
7373
m = 0.0
@@ -89,12 +89,12 @@ obs = [S2 ~ 2 * S]
8989
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
9090
js2b = complete(js2b)
9191
dprob = DiscreteProblem(js2b, u₀map, tspan, parammap)
92-
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false), rng = rng)
92+
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false), rng)
9393
sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
9494
@test all(2 .* sol[S] .== sol[S2])
9595

9696
# test save_positions is working
97-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
97+
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng)
9898
sol = solve(jprob, SSAStepper(), saveat = 1.0)
9999
@test all((sol.t) .== collect(0.0:tspan[2]))
100100

@@ -129,7 +129,7 @@ function a2!(integrator)
129129
end
130130
j2 = ConstantRateJump(r2, a2!)
131131
jset = JumpSet((), (j1, j2), nothing, nothing)
132-
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), rng = rng)
132+
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), rng)
133133
m2 = getmean(jprob, Nsims)
134134

135135
# test JumpSystem solution agrees with direct version
@@ -141,17 +141,17 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
141141
@named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
142142
js3 = complete(js3)
143143
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
144-
jprob = JumpProblem(js3, dprob, Direct(), rng = rng)
144+
jprob = JumpProblem(js3, dprob, Direct(), rng)
145145
m3 = getmean(jprob, Nsims)
146146
@test abs(m - m3) / m < 0.01
147147

148148
# maj jump test with various dep graphs
149149
@named js3b = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
150150
js3b = complete(js3b)
151-
jprobb = JumpProblem(js3b, dprob, NRM(), rng = rng)
151+
jprobb = JumpProblem(js3b, dprob, NRM(), rng)
152152
m4 = getmean(jprobb, Nsims)
153153
@test abs(m - m4) / m < 0.01
154-
jprobc = JumpProblem(js3b, dprob, RSSA(), rng = rng)
154+
jprobc = JumpProblem(js3b, dprob, RSSA(), rng)
155155
m4 = getmean(jprobc, Nsims)
156156
@test abs(m - m4) / m < 0.01
157157

@@ -161,7 +161,7 @@ maj2 = MassActionJump(γ, [S => 1], [S => -1])
161161
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
162162
js4 = complete(js4)
163163
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
164-
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
164+
jprob = JumpProblem(js4, dprob, Direct(), rng)
165165
m4 = getmean(jprob, Nsims)
166166
@test abs(m4 - 2.0 / 0.01) * 0.01 / 2.0 < 0.01
167167

@@ -171,7 +171,7 @@ maj2 = MassActionJump(γ, [S => 2], [S => -1])
171171
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
172172
js4 = complete(js4)
173173
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
174-
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
174+
jprob = JumpProblem(js4, dprob, Direct(), rng)
175175
sol = solve(jprob, SSAStepper());
176176

177177
# issue #819
@@ -183,28 +183,30 @@ sol = solve(jprob, SSAStepper());
183183
end
184184

185185
# test if param mapper is setup correctly for callbacks
186-
@parameters k1 k2 k3
187-
@variables A(t) B(t)
188-
maj1 = MassActionJump(k1 * k3, [0 => 1], [A => -1, B => 1])
189-
maj2 = MassActionJump(k2, [B => 1], [A => 1, B => -1])
190-
@named js5 = JumpSystem([maj1, maj2], t, [A, B], [k1, k2, k3])
191-
js5 = complete(js5)
192-
p = [k1 => 2.0, k2 => 0.0, k3 => 0.5]
193-
u₀ = [A => 100, B => 0]
194-
tspan = (0.0, 2000.0)
195-
dprob = DiscreteProblem(js5, u₀, tspan, p)
196-
jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false), rng = rng)
197-
@test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0])
186+
let
187+
@parameters k1 k2 k3
188+
@variables A(t) B(t)
189+
maj1 = MassActionJump(k1 * k3, [0 => 1], [A => -1, B => 1])
190+
maj2 = MassActionJump(k2, [B => 1], [A => 1, B => -1])
191+
@named js5 = JumpSystem([maj1, maj2], t, [A, B], [k1, k2, k3])
192+
js5 = complete(js5)
193+
p = [k1 => 2.0, k2 => 0.0, k3 => 0.5]
194+
u₀ = [A => 100, B => 0]
195+
tspan = (0.0, 2000.0)
196+
dprob = DiscreteProblem(js5, u₀, tspan, p)
197+
jprob = JumpProblem(js5, dprob, Direct(); save_positions = (false, false), rng)
198+
@test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0])
198199

199-
pcondit(u, t, integrator) = t == 1000.0
200-
function paffect!(integrator)
201-
integrator.ps[k1] = 0.0
202-
integrator.ps[k2] = 1.0
203-
reset_aggregated_jumps!(integrator)
200+
pcondit(u, t, integrator) = t == 1000.0
201+
function paffect!(integrator)
202+
integrator.ps[k1] = 0.0
203+
integrator.ps[k2] = 1.0
204+
reset_aggregated_jumps!(integrator)
205+
end
206+
cb = DiscreteCallback(pcondit, paffect!)
207+
sol = solve(jprob, SSAStepper(); tstops = [1000.0], callback = cb)
208+
@test sol.u[end][1] == 100
204209
end
205-
sol = solve(jprob, SSAStepper(), tstops = [1000.0],
206-
callback = DiscreteCallback(pcondit, paffect!))
207-
@test_skip sol.u[end][1] == 100 # TODO: Fix mass-action jumps in JumpProcesses
208210

209211
# observed variable handling
210212
@variables OBS(t)

0 commit comments

Comments
 (0)