Skip to content

Commit 782c430

Browse files
authored
Merge pull request #520 from isaacsas/add_leap_aggregator
Add leap aggregator
2 parents f2d3efe + dc47666 commit 782c430

File tree

7 files changed

+160
-29
lines changed

7 files changed

+160
-29
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ FunctionWrappers = "1.1"
4545
Graphs = "1.11"
4646
KernelAbstractions = "0.9"
4747
LinearAlgebra = "1"
48-
LinearSolve = "2, 3"
48+
LinearSolve = "3"
4949
OrdinaryDiffEq = "6"
5050
Pkg = "1"
5151
PoissonRandom = "0.4"

ext/JumpProcessesKernelAbstractionsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem,
2222

2323
jump_prob = ensembleprob.prob
2424

25-
# boilerplate from SimpleTauLeaping method
26-
@assert isempty(jump_prob.jump_callback.continuous_callbacks) # still needs to be a regular jump
27-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
25+
# Validate that this is a PureLeaping JumpProblem
26+
validate_pure_leaping_inputs(jump_prob, alg) ||
27+
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
2828
prob = jump_prob.prob
2929

3030
probs = [remake(jump_prob) for _ in 1:trajectories]

src/JumpProcesses.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ export ExtendedJumpArray
110110
include("variable_rate.jl")
111111
export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW
112112

113+
"""
114+
Aggregator to indicate that individual jumps should also be handled via the leaping
115+
algorithm that is passed to solve.
116+
"""
117+
struct PureLeaping <: AbstractAggregatorAlgorithm end
118+
export PureLeaping
119+
113120
# core problem and timestepping
114121
include("problem.jl")
115122
export JumpProblem, SplitCoupledJumpProblem

src/problem.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,48 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
302302
jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs)
303303
end
304304

305+
# Special dispatch for PureLeaping aggregator - bypasses all aggregation
306+
function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet;
307+
save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ?
308+
(false, true) : (true, true),
309+
rng = DEFAULT_RNG, scale_rates = true, useiszero = true,
310+
spatial_system = nothing, hopping_constants = nothing,
311+
callback = nothing, kwargs...)
312+
313+
# Validate no spatial systems (not currently supported)
314+
(spatial_system !== nothing || hopping_constants !== nothing) &&
315+
error("PureLeaping does not currently support spatial problems.")
316+
317+
# Initialize the MassActionJump rate constants with the user parameters
318+
if using_params(jumps.massaction_jump)
319+
rates = jumps.massaction_jump.param_mapper(prob.p)
320+
maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch,
321+
jumps.massaction_jump.net_stoch,
322+
jumps.massaction_jump.param_mapper; scale_rates = scale_rates,
323+
useiszero = useiszero,
324+
nocopy = true)
325+
else
326+
maj = jumps.massaction_jump
327+
end
328+
329+
# For PureLeaping, all jumps are handled by the tau-leaping solver
330+
# No discrete jump aggregation or variable rate callbacks are created
331+
disc_agg = nothing
332+
jump_cbs = CallbackSet()
333+
334+
# Store all jump types for access by tau-leaping solver
335+
crjs = jumps.constant_jumps
336+
vrjs = jumps.variable_jumps
337+
338+
iip = isinplace_jump(prob, jumps.regular_jump)
339+
solkwargs = make_kwarg(; callback)
340+
341+
JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs),
342+
typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump),
343+
typeof(maj), typeof(rng), typeof(solkwargs)}(prob, aggregator, disc_agg,
344+
jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs)
345+
end
346+
305347
aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A
306348

307349
@inline function extend_tstops!(tstops, jp::JumpProblem)

src/simple_regular_solve.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end
22

3+
function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
4+
if !(jump_prob.aggregator isa PureLeaping)
5+
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
6+
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
7+
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
8+
end
9+
isempty(jump_prob.jump_callback.continuous_callbacks) &&
10+
isempty(jump_prob.jump_callback.discrete_callbacks) &&
11+
isempty(jump_prob.constant_jumps) &&
12+
isempty(jump_prob.variable_jumps) &&
13+
get_num_majumps(jump_prob.massaction_jump) == 0 &&
14+
jump_prob.regular_jump !== nothing
15+
end
16+
317
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
418
seed = nothing,
519
dt = error("dt is required for SimpleTauLeaping."))
6-
7-
# boilerplate from SimpleTauLeaping method
8-
@assert isempty(jump_prob.jump_callback.continuous_callbacks) # still needs to be a regular jump
9-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
20+
validate_pure_leaping_inputs(jump_prob, alg) ||
21+
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
1022
prob = jump_prob.prob
1123
rng = DEFAULT_RNG
1224
(seed !== nothing) && seed!(rng, seed)
@@ -62,4 +74,3 @@ end
6274
function EnsembleGPUKernel()
6375
EnsembleGPUKernel(nothing, 0.0)
6476
end
65-

test/gpu/regular_jumps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ let
3131

3232
prob_disc = DiscreteProblem(u0, tspan, p)
3333
rj = RegularJump(regular_rate, regular_c, 3)
34-
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34+
jump_prob = JumpProblem(prob_disc, PureLeaping(), rj)
3535

3636
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(),
3737
EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0)
@@ -72,7 +72,7 @@ let
7272
# Create JumpProblem
7373
prob_disc = DiscreteProblem(u0, tspan, p)
7474
rj = RegularJump(regular_rate, regular_c, 3)
75-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng = StableRNG(12345))
75+
jump_prob = JumpProblem(prob_disc, PureLeaping(), rj; rng = StableRNG(12345))
7676

7777
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(),
7878
EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0)

test/regular_jumps.jl

Lines changed: 89 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,7 @@ function regular_rate(out, u, p, t)
88
out[2] = 0.01u[2]
99
end
1010

11-
function regular_c(dc, u, p, t, mark)
12-
dc[1, 1] = -1
13-
dc[2, 1] = 1
14-
dc[2, 2] = -1
15-
dc[3, 2] = 1
16-
end
17-
18-
dc = zeros(3, 2)
19-
20-
rj = RegularJump(regular_rate, regular_c, dc; constant_c = true)
21-
jumps = JumpSet(rj)
22-
23-
prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0))
24-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
25-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
26-
27-
const _dc = zeros(3, 2)
11+
const dc = zeros(3, 2)
2812
dc[1, 1] = -1
2913
dc[2, 1] = 1
3014
dc[2, 2] = -1
@@ -37,5 +21,92 @@ end
3721
rj = RegularJump(regular_rate, regular_c, 2)
3822
jumps = JumpSet(rj)
3923
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
40-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
24+
jump_prob = JumpProblem(prob, PureLeaping(), rj; rng)
4125
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
26+
27+
# Test PureLeaping aggregator functionality
28+
@testset "PureLeaping Aggregator Tests" begin
29+
# Test with MassActionJump
30+
u0 = [10, 5, 0]
31+
tspan = (0.0, 10.0)
32+
p = [0.1, 0.2]
33+
prob = DiscreteProblem(u0, tspan, p)
34+
35+
# Create MassActionJump
36+
reactant_stoich = [[1 => 1], [1 => 2]]
37+
net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]]
38+
rates = [0.1, 0.05]
39+
maj = MassActionJump(rates, reactant_stoich, net_stoich)
40+
41+
# Test PureLeaping JumpProblem creation
42+
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj))
43+
@test jp_pure.aggregator isa PureLeaping
44+
@test jp_pure.discrete_jump_aggregation === nothing
45+
@test jp_pure.massaction_jump !== nothing
46+
@test length(jp_pure.jump_callback.discrete_callbacks) == 0
47+
48+
# Test with ConstantRateJump
49+
rate(u, p, t) = p[1] * u[1]
50+
affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
51+
crj = ConstantRateJump(rate, affect!)
52+
53+
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj))
54+
@test jp_pure_crj.aggregator isa PureLeaping
55+
@test jp_pure_crj.discrete_jump_aggregation === nothing
56+
@test length(jp_pure_crj.constant_jumps) == 1
57+
58+
# Test with VariableRateJump
59+
vrate(u, p, t) = t * p[1] * u[1]
60+
vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
61+
vrj = VariableRateJump(vrate, vaffect!)
62+
63+
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj))
64+
@test jp_pure_vrj.aggregator isa PureLeaping
65+
@test jp_pure_vrj.discrete_jump_aggregation === nothing
66+
@test length(jp_pure_vrj.variable_jumps) == 1
67+
68+
# Test with RegularJump
69+
function rj_rate(out, u, p, t)
70+
out[1] = p[1] * u[1]
71+
end
72+
73+
rj_dc = zeros(3, 1)
74+
rj_dc[1, 1] = -1
75+
rj_dc[3, 1] = 1
76+
77+
function rj_c(du, u, p, t, counts, mark)
78+
mul!(du, rj_dc, counts)
79+
end
80+
81+
regj = RegularJump(rj_rate, rj_c, 1)
82+
83+
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj))
84+
@test jp_pure_regj.aggregator isa PureLeaping
85+
@test jp_pure_regj.discrete_jump_aggregation === nothing
86+
@test jp_pure_regj.regular_jump !== nothing
87+
88+
# Test mixed jump types
89+
mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,),
90+
variable_jumps = (vrj,), regular_jumps = regj)
91+
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps)
92+
@test jp_pure_mixed.aggregator isa PureLeaping
93+
@test jp_pure_mixed.discrete_jump_aggregation === nothing
94+
@test jp_pure_mixed.massaction_jump !== nothing
95+
@test length(jp_pure_mixed.constant_jumps) == 1
96+
@test length(jp_pure_mixed.variable_jumps) == 1
97+
@test jp_pure_mixed.regular_jump !== nothing
98+
99+
# Test spatial system error
100+
spatial_sys = CartesianGrid((2, 2))
101+
hopping_consts = [1.0]
102+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
103+
spatial_system = spatial_sys)
104+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
105+
hopping_constants = hopping_consts)
106+
107+
# Test MassActionJump with parameter mapping
108+
maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2])
109+
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params))
110+
scaled_rates = [p[1], p[2]/2]
111+
@test jp_params.massaction_jump.scaled_rates == scaled_rates
112+
end

0 commit comments

Comments
 (0)