Skip to content

Commit 3f7ad46

Browse files
Merge pull request #2931 from isaacsas/auto_alg_jsys_support
update JumpSystem for auto-alg support
2 parents adb6a84 + 4ef25cb commit 3f7ad46

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ FunctionWrappersWrappers = "0.1"
8989
Graphs = "1.5.2"
9090
InteractiveUtils = "1"
9191
JuliaFormatter = "1.0.47"
92-
JumpProcesses = "9.1"
92+
JumpProcesses = "9.13.1"
9393
LabelledArrays = "1.3"
9494
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
9595
Libdl = "1"

src/systems/jumps/jumpsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ jprob = JumpProblem(complete(js), dprob, Direct())
426426
sol = solve(jprob, SSAStepper())
427427
```
428428
"""
429-
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing,
429+
function JumpProcesses.JumpProblem(js::JumpSystem, prob,
430+
aggregator = JumpProcesses.NullAggregator(); callback = nothing,
430431
eval_expression = false, eval_module = @__MODULE__, kwargs...)
431432
if !iscomplete(js)
432433
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `JumpProblem`")
@@ -448,7 +449,8 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback =
448449
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
449450
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
450451

451-
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
452+
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) ||
453+
(aggregator isa JumpProcesses.NullAggregator)
452454
jdeps = asgraph(js)
453455
vdeps = variable_dependencies(js)
454456
vtoj = jdeps.badjlist

test/jumpsystem.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,21 @@ parammap = [β => 0.1 / 1000, γ => 0.01]
6969
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
7070
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
7171
Nsims = 30000
72-
function getmean(jprob, Nsims)
72+
function getmean(jprob, Nsims; use_stepper = true)
7373
m = 0.0
7474
for i in 1:Nsims
75-
sol = solve(jprob, SSAStepper())
75+
sol = use_stepper ? solve(jprob, SSAStepper()) : solve(jprob)
7676
m += sol[end, end]
7777
end
7878
m / Nsims
7979
end
8080
m = getmean(jprob, Nsims)
8181

82+
# test auto-alg selection works
83+
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)
84+
mb = getmean(jprobb, Nsims; use_stepper = false)
85+
@test abs(m - mb) / m < 0.01
86+
8287
@variables S2(t)
8388
obs = [S2 ~ 2 * S]
8489
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
@@ -89,7 +94,6 @@ sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
8994
@test all(2 .* sol[S] .== sol[S2])
9095

9196
# test save_positions is working
92-
9397
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
9498
sol = solve(jprob, SSAStepper(), saveat = 1.0)
9599
@test all((sol.t) .== collect(0.0:tspan[2]))
@@ -270,3 +274,22 @@ affect = [X ~ X - 1]
270274

271275
j1 = ConstantRateJump(k, [X ~ X - 1])
272276
@test_nowarn @mtkbuild js1 = JumpSystem([j1], t, [X], [k])
277+
278+
# test correct autosolver is selected, which implies appropriate dep graphs are available
279+
let
280+
@parameters k
281+
@variables X(t)
282+
rate = k
283+
affect = [X ~ X - 1]
284+
j1 = ConstantRateJump(k, [X ~ X - 1])
285+
286+
Nv = [1, JumpProcesses.USE_DIRECT_THRESHOLD + 1, JumpProcesses.USE_RSSA_THRESHOLD + 1]
287+
algtypes = [Direct, RSSA, RSSACR]
288+
for (N, algtype) in zip(Nv, algtypes)
289+
@named jsys = JumpSystem([deepcopy(j1) for _ in 1:N], t, [X], [k])
290+
jsys = complete(jsys)
291+
dprob = DiscreteProblem(jsys, [X => 10], (0.0, 10.0), [k => 1])
292+
jprob = JumpProblem(jsys, dprob)
293+
@test jprob.aggregator isa algtype
294+
end
295+
end

0 commit comments

Comments
 (0)