Skip to content

Commit f2d3efe

Browse files
authored
Merge pull request #519 from isaacsas/cleanup_leaping_and_support_nonregjumps
Cleanup exports and add constant rate jumps to JumpProblem
2 parents 93a8917 + b4fc959 commit f2d3efe

File tree

5 files changed

+83
-69
lines changed

5 files changed

+83
-69
lines changed

src/JumpProcesses.jl

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -56,73 +56,94 @@ const USE_RSSA_THRESHOLD = 100
5656
const USE_SORTINGDIRECT_THRESHOLD = 200
5757

5858
include("jumps.jl")
59+
export ConstantRateJump, VariableRateJump, RegularJump, MassActionJump, JumpSet
60+
5961
include("massaction_rates.jl")
62+
63+
# constant rate aggregators (i.e. SSAs)
6064
include("aggregators/aggregators.jl")
65+
export get_num_majumps, needs_depgraph, needs_vartojumps_map, reset_aggregated_jumps!
66+
6167
include("aggregators/ssajump.jl")
68+
6269
include("aggregators/direct.jl")
70+
export Direct, DirectFW
71+
6372
include("aggregators/frm.jl")
73+
export FRM, FRMFW
74+
6475
include("aggregators/sortingdirect.jl")
76+
export SortingDirect
77+
6578
include("aggregators/nrm.jl")
79+
export NRM
80+
6681
include("aggregators/bracketing.jl")
82+
export BracketData
83+
6784
include("aggregators/rssa.jl")
85+
export RSSA
86+
6887
include("aggregators/prioritytable.jl")
88+
6989
include("aggregators/directcr.jl")
90+
export DirectCR
91+
7092
include("aggregators/rssacr.jl")
93+
export RSSACR
94+
7195
include("aggregators/rdirect.jl")
72-
include("aggregators/coevolve.jl")
73-
include("aggregators/ccnrm.jl")
96+
export RDirect
7497

75-
# spatial:
76-
include("spatial/spatial_massaction_jump.jl")
77-
include("spatial/topology.jl")
78-
include("spatial/hop_rates.jl")
79-
include("spatial/reaction_rates.jl")
80-
include("spatial/flatten.jl")
81-
include("spatial/utils.jl")
82-
include("spatial/bracketing.jl")
98+
include("aggregators/coevolve.jl")
99+
export Coevolve
83100

84-
include("spatial/nsm.jl")
85-
include("spatial/directcrdirect.jl")
101+
include("aggregators/ccnrm.jl")
102+
export CCNRM
86103

87104
include("aggregators/aggregated_api.jl")
88105

106+
# variable rate aggregators (i.e. SSAs)
89107
include("extended_jump_array.jl")
90-
include("variable_rate.jl")
91-
include("problem.jl")
92-
include("solve.jl")
93-
include("coupled_array.jl")
94-
include("coupling.jl")
95-
include("SSA_stepper.jl")
96-
include("simple_regular_solve.jl")
108+
export ExtendedJumpArray
97109

98-
export ConstantRateJump, VariableRateJump, RegularJump,
99-
MassActionJump, JumpSet
110+
include("variable_rate.jl")
111+
export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW
100112

101-
export JumpProblem
113+
# core problem and timestepping
114+
include("problem.jl")
115+
export JumpProblem, SplitCoupledJumpProblem
102116

103-
export SplitCoupledJumpProblem
117+
include("solve.jl")
118+
export init, solve, solve!
104119

105-
export Direct, DirectFW, SortingDirect, DirectCR
106-
export BracketData, RSSA
107-
export FRM, FRMFW, NRM, CCNRM
108-
export RSSACR, RDirect
109-
export Coevolve
120+
include("SSA_stepper.jl")
121+
export SSAStepper
110122

111-
export get_num_majumps, needs_depgraph, needs_vartojumps_map
123+
# leaping:
124+
include("simple_regular_solve.jl")
125+
export SimpleTauLeaping, EnsembleGPUKernel
112126

113-
export init, solve, solve!
127+
# spatial:
128+
include("spatial/spatial_massaction_jump.jl")
129+
export SpatialMassActionJump
114130

115-
export reset_aggregated_jumps!
131+
include("spatial/topology.jl")
132+
export CartesianGrid, CartesianGridRej, outdegree, num_sites, neighbors
116133

117-
export ExtendedJumpArray
134+
include("spatial/hop_rates.jl")
135+
include("spatial/reaction_rates.jl")
136+
include("spatial/flatten.jl")
137+
include("spatial/utils.jl")
138+
include("spatial/bracketing.jl")
139+
include("spatial/nsm.jl")
140+
export NSM
118141

119-
# Export VariableRateAggregator types
120-
export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW
142+
include("spatial/directcrdirect.jl")
143+
export DirectCRDirect
121144

122-
# spatial structs and functions
123-
export CartesianGrid, CartesianGridRej
124-
export SpatialMassActionJump
125-
export outdegree, num_sites, neighbors
126-
export NSM, DirectCRDirect
145+
# coupling
146+
include("coupled_array.jl")
147+
include("coupling.jl")
127148

128149
end # module

src/SSA_stepper.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ function DiffEqBase.terminate!(integrator::SSAIntegrator, retcode = ReturnCode.T
381381
nothing
382382
end
383383

384-
export SSAStepper
385384

386385
function SciMLBase.isdenseplot(sol::ODESolution{
387386
T, N, uType, uType2, DType, tType, rateType, discType, P,

src/problem.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_e
6767
the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage
6868
examples and commonly asked questions.
6969
"""
70-
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J2,
71-
J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
70+
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1,
71+
J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
7272
"""The type of problem to couple the jumps to. For a pure jump process use `DiscreteProblem`, to couple to ODEs, `ODEProblem`, etc."""
7373
prob::P
7474
"""The aggregator algorithm that determines the next jump times and types for `ConstantRateJump`s and `MassActionJump`s. Examples include `Direct`."""
@@ -77,6 +77,8 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega
7777
discrete_jump_aggregation::J
7878
"""`CallBackSet` with the underlying `ConstantRate` and `VariableRate` jumps."""
7979
jump_callback::C
80+
"""The `ConstantRateJump`s."""
81+
constant_jumps::J1
8082
"""The `VariableRateJump`s."""
8183
variable_jumps::J2
8284
"""The `RegularJump`s."""
@@ -88,10 +90,11 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega
8890
"""kwargs to pass on to solve call."""
8991
kwargs::K
9092
end
91-
function JumpProblem(p::P, a::A, dj::J, jc::C, vj::J2, rj::J3, mj::J4,
92-
rng::R, kwargs::K) where {P, A, J, C, J2, J3, J4, R, K}
93+
function JumpProblem(p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4,
94+
rng::R, kwargs::K) where {P, A, J, C, J1, J2, J3, J4, R, K}
9395
iip = isinplace_jump(p, rj)
94-
JumpProblem{iip, P, A, C, J, J2, J3, J4, R, K}(p, a, dj, jc, vj, rj, mj, rng, kwargs)
96+
JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}(p, a, dj, jc, cj, vj, rj, mj,
97+
rng, kwargs)
9598
end
9699

97100
######## remaking ######
@@ -154,8 +157,8 @@ function DiffEqBase.remake(jprob::JumpProblem; kwargs...)
154157
end
155158

156159
T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback,
157-
jprob.variable_jumps, jprob.regular_jump, jprob.massaction_jump, jprob.rng,
158-
jprob.kwargs)
160+
jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump,
161+
jprob.massaction_jump, jprob.rng, jprob.kwargs)
159162
end
160163

161164
# for updating parameters in JumpProblems to update MassActionJumps
@@ -253,6 +256,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
253256
end
254257

255258
ndiscjumps = get_num_majumps(maj) + num_crjs(jumps)
259+
crjs = jumps.constant_jumps
256260

257261
# separate bounded variable rate jumps *if* the aggregator can use them
258262
if use_vrj_bounds && supports_variablerates(aggregator) && (num_bndvrjs(jumps) > 0)
@@ -272,41 +276,35 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
272276
disc_agg = nothing
273277
constant_jump_callback = CallbackSet()
274278
else
275-
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj,
279+
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, crjs, maj,
276280
save_positions, rng; kwargs...)
277281
constant_jump_callback = DiscreteCallback(disc_agg)
278282
end
279283

280284
# handle any remaining vrjs
281285
if length(cvrjs) > 0
282286
# Handle variable rate jumps based on vr_aggregator
283-
new_prob, variable_jump_callback,
284-
cont_agg = configure_jump_problem(prob,
285-
vr_aggregator, jumps, cvrjs; rng)
287+
new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator,
288+
jumps, cvrjs; rng)
286289
else
287290
new_prob = prob
288291
variable_jump_callback = CallbackSet()
289-
cont_agg = JumpSet().variable_jumps
292+
cvrjs = JumpSet().variable_jumps
290293
end
291294

292295
jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback)
293296
iip = isinplace_jump(prob, jumps.regular_jump)
294297
solkwargs = make_kwarg(; callback)
295298

296-
JumpProblem{iip, typeof(new_prob), typeof(aggregator),
297-
typeof(jump_cbs), typeof(disc_agg),
298-
typeof(cont_agg),
299-
typeof(jumps.regular_jump),
299+
JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs),
300+
typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump),
300301
typeof(maj), typeof(rng), typeof(solkwargs)}(new_prob, aggregator, disc_agg,
301-
jump_cbs, cont_agg,
302-
jumps.regular_jump, maj, rng,
303-
solkwargs)
302+
jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs)
304303
end
305304

306-
aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
305+
aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A
307306

308-
@inline function extend_tstops!(tstops,
309-
jp::JumpProblem{P, A, C, J, J2}) where {P, A, C, J, J2}
307+
@inline function extend_tstops!(tstops, jp::JumpProblem)
310308
!(jp.jump_callback.discrete_callbacks isa Tuple{}) &&
311309
push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time)
312310
end

src/simple_regular_solve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,3 @@ function EnsembleGPUKernel()
6363
EnsembleGPUKernel(nothing, 0.0)
6464
end
6565

66-
export SimpleTauLeaping, EnsembleGPUKernel

src/variable_rate.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs;
6262
rng = DEFAULT_RNG)
6363
new_prob = extend_problem(prob, cvrjs; rng)
6464
variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng)
65-
cont_agg = cvrjs
66-
return new_prob, variable_jump_callback, cont_agg
65+
return new_prob, variable_jump_callback
6766
end
6867

6968
# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values,
@@ -381,16 +380,14 @@ function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_R
381380
new_prob = prob
382381
cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng)
383382
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
384-
cont_agg = cvrjs
385-
return new_prob, variable_jump_callback, cont_agg
383+
return new_prob, variable_jump_callback
386384
end
387385

388386
function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG)
389387
new_prob = prob
390388
cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng)
391389
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
392-
cont_agg = cvrjs
393-
return new_prob, variable_jump_callback, cont_agg
390+
return new_prob, variable_jump_callback
394391
end
395392

396393
# recursively evaluate the cumulative sum of the rates for type stability

0 commit comments

Comments
 (0)