Skip to content

Commit c2283a2

Browse files
committed
add constant_rate field to JumpProblem
1 parent 442f392 commit c2283a2

File tree

3 files changed

+30
-28
lines changed

3 files changed

+30
-28
lines changed

src/JumpProcesses.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ include("SSA_stepper.jl")
121121
export SSAStepper
122122

123123
# leaping:
124+
"""
125+
Aggregator to indicate that individual jumps should also be handled via the leaping
126+
algorithm that is passed to solve.
127+
"""
128+
struct LeapOnly <: AbstractAggregatorAlgorithm end
129+
export LeapOnly
130+
124131
include("simple_regular_solve.jl")
125132
export SimpleTauLeaping, EnsembleGPUKernel
126133

src/problem.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_e
6767
the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage
6868
examples and commonly asked questions.
6969
"""
70-
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J2,
71-
J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
70+
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1,
71+
J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
7272
"""The type of problem to couple the jumps to. For a pure jump process use `DiscreteProblem`, to couple to ODEs, `ODEProblem`, etc."""
7373
prob::P
7474
"""The aggregator algorithm that determines the next jump times and types for `ConstantRateJump`s and `MassActionJump`s. Examples include `Direct`."""
@@ -77,6 +77,8 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega
7777
discrete_jump_aggregation::J
7878
"""`CallBackSet` with the underlying `ConstantRate` and `VariableRate` jumps."""
7979
jump_callback::C
80+
"""The `ConstantRateJump`s."""
81+
constant_jumps::J1
8082
"""The `VariableRateJump`s."""
8183
variable_jumps::J2
8284
"""The `RegularJump`s."""
@@ -88,10 +90,11 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega
8890
"""kwargs to pass on to solve call."""
8991
kwargs::K
9092
end
91-
function JumpProblem(p::P, a::A, dj::J, jc::C, vj::J2, rj::J3, mj::J4,
92-
rng::R, kwargs::K) where {P, A, J, C, J2, J3, J4, R, K}
93+
function JumpProblem(p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4,
94+
rng::R, kwargs::K) where {P, A, J, C, J1, J2, J3, J4, R, K}
9395
iip = isinplace_jump(p, rj)
94-
JumpProblem{iip, P, A, C, J, J2, J3, J4, R, K}(p, a, dj, jc, vj, rj, mj, rng, kwargs)
96+
JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}(p, a, dj, jc, cj, vj, rj, mj,
97+
rng, kwargs)
9598
end
9699

97100
######## remaking ######
@@ -154,8 +157,8 @@ function DiffEqBase.remake(jprob::JumpProblem; kwargs...)
154157
end
155158

156159
T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback,
157-
jprob.variable_jumps, jprob.regular_jump, jprob.massaction_jump, jprob.rng,
158-
jprob.kwargs)
160+
jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump,
161+
jprob.massaction_jump, jprob.rng, jprob.kwargs)
159162
end
160163

161164
# for updating parameters in JumpProblems to update MassActionJumps
@@ -253,6 +256,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
253256
end
254257

255258
ndiscjumps = get_num_majumps(maj) + num_crjs(jumps)
259+
crjs = jumps.constant_jumps
256260

257261
# separate bounded variable rate jumps *if* the aggregator can use them
258262
if use_vrj_bounds && supports_variablerates(aggregator) && (num_bndvrjs(jumps) > 0)
@@ -272,41 +276,35 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
272276
disc_agg = nothing
273277
constant_jump_callback = CallbackSet()
274278
else
275-
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj,
279+
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, crjs, maj,
276280
save_positions, rng; kwargs...)
277281
constant_jump_callback = DiscreteCallback(disc_agg)
278282
end
279283

280284
# handle any remaining vrjs
281285
if length(cvrjs) > 0
282286
# Handle variable rate jumps based on vr_aggregator
283-
new_prob, variable_jump_callback,
284-
cont_agg = configure_jump_problem(prob,
285-
vr_aggregator, jumps, cvrjs; rng)
287+
new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator,
288+
jumps, cvrjs; rng)
286289
else
287290
new_prob = prob
288291
variable_jump_callback = CallbackSet()
289-
cont_agg = JumpSet().variable_jumps
292+
cvrjs = JumpSet().variable_jumps
290293
end
291294

292295
jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback)
293296
iip = isinplace_jump(prob, jumps.regular_jump)
294297
solkwargs = make_kwarg(; callback)
295298

296-
JumpProblem{iip, typeof(new_prob), typeof(aggregator),
297-
typeof(jump_cbs), typeof(disc_agg),
298-
typeof(cont_agg),
299-
typeof(jumps.regular_jump),
299+
JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs),
300+
typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump),
300301
typeof(maj), typeof(rng), typeof(solkwargs)}(new_prob, aggregator, disc_agg,
301-
jump_cbs, cont_agg,
302-
jumps.regular_jump, maj, rng,
303-
solkwargs)
302+
jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs)
304303
end
305304

306-
aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
305+
aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A
307306

308-
@inline function extend_tstops!(tstops,
309-
jp::JumpProblem{P, A, C, J, J2}) where {P, A, C, J, J2}
307+
@inline function extend_tstops!(tstops, jp::JumpProblem)
310308
!(jp.jump_callback.discrete_callbacks isa Tuple{}) &&
311309
push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time)
312310
end

src/variable_rate.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs;
6262
rng = DEFAULT_RNG)
6363
new_prob = extend_problem(prob, cvrjs; rng)
6464
variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng)
65-
cont_agg = cvrjs
66-
return new_prob, variable_jump_callback, cont_agg
65+
return new_prob, variable_jump_callback
6766
end
6867

6968
# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values,
@@ -381,16 +380,14 @@ function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_R
381380
new_prob = prob
382381
cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng)
383382
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
384-
cont_agg = cvrjs
385-
return new_prob, variable_jump_callback, cont_agg
383+
return new_prob, variable_jump_callback
386384
end
387385

388386
function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG)
389387
new_prob = prob
390388
cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng)
391389
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
392-
cont_agg = cvrjs
393-
return new_prob, variable_jump_callback, cont_agg
390+
return new_prob, variable_jump_callback
394391
end
395392

396393
# recursively evaluate the cumulative sum of the rates for type stability

0 commit comments

Comments
 (0)