Skip to content

Commit 4fd172d

Browse files
committed
PureLeaping aggregator
1 parent d0bd532 commit 4fd172d

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

src/JumpProcesses.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ include("SSA_stepper.jl")
121121
export SSAStepper
122122

123123
# leaping:
124+
"""
125+
Aggregator to indicate that individual jumps should also be handled via the leaping
126+
algorithm that is passed to solve.
127+
"""
128+
struct PureLeaping <: AbstractAggregatorAlgorithm end
129+
export PureLeaping
130+
124131
include("simple_regular_solve.jl")
125132
export SimpleTauLeaping, EnsembleGPUKernel
126133

src/problem.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,48 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
302302
jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs)
303303
end
304304

305+
# Special dispatch for PureLeaping aggregator - bypasses all aggregation
306+
function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet;
307+
save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ?
308+
(false, true) : (true, true),
309+
rng = DEFAULT_RNG, scale_rates = true, useiszero = true,
310+
spatial_system = nothing, hopping_constants = nothing,
311+
callback = nothing, kwargs...)
312+
313+
# Validate no spatial systems (not currently supported)
314+
(spatial_system !== nothing || hopping_constants !== nothing) &&
315+
error("PureLeaping does not currently support spatial problems.")
316+
317+
# Initialize the MassActionJump rate constants with the user parameters
318+
if using_params(jumps.massaction_jump)
319+
rates = jumps.massaction_jump.param_mapper(prob.p)
320+
maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch,
321+
jumps.massaction_jump.net_stoch,
322+
jumps.massaction_jump.param_mapper; scale_rates = scale_rates,
323+
useiszero = useiszero,
324+
nocopy = true)
325+
else
326+
maj = jumps.massaction_jump
327+
end
328+
329+
# For PureLeaping, all jumps are handled by the tau-leaping solver
330+
# No discrete jump aggregation or variable rate callbacks are created
331+
disc_agg = nothing
332+
jump_cbs = CallbackSet()
333+
334+
# Store all jump types for access by tau-leaping solver
335+
crjs = jumps.constant_jumps
336+
vrjs = jumps.variable_jumps
337+
338+
iip = isinplace_jump(prob, jumps.regular_jump)
339+
solkwargs = make_kwarg(; callback)
340+
341+
JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs),
342+
typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump),
343+
typeof(maj), typeof(rng), typeof(solkwargs)}(prob, aggregator, disc_agg,
344+
jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs)
345+
end
346+
305347
aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A
306348

307349
@inline function extend_tstops!(tstops, jp::JumpProblem)

test/regular_jumps.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,87 @@ jumps = JumpSet(rj)
3939
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
4040
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
4141
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
42+
43+
# Test PureLeaping aggregator functionality
44+
@testset "PureLeaping Aggregator Tests" begin
45+
# Test with MassActionJump
46+
u0 = [10, 5, 0]
47+
tspan = (0.0, 10.0)
48+
p = [0.1, 0.2]
49+
prob = DiscreteProblem(u0, p, tspan)
50+
51+
# Create MassActionJump
52+
reactant_stoich = [[1 => 1], [1 => 2]]
53+
net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]]
54+
rates = [0.1, 0.05]
55+
maj = MassActionJump(rates, reactant_stoich, net_stoich)
56+
57+
# Test PureLeaping JumpProblem creation
58+
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj))
59+
@test jp_pure.aggregator isa PureLeaping
60+
@test jp_pure.discrete_jump_aggregation === nothing
61+
@test jp_pure.massaction_jump !== nothing
62+
@test length(jp_pure.jump_callback.discrete_callbacks) == 0
63+
64+
# Test with ConstantRateJump
65+
rate(u, p, t) = p[1] * u[1]
66+
affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
67+
crj = ConstantRateJump(rate, affect!)
68+
69+
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj))
70+
@test jp_pure_crj.aggregator isa PureLeaping
71+
@test jp_pure_crj.discrete_jump_aggregation === nothing
72+
@test length(jp_pure_crj.constant_jumps) == 1
73+
74+
# Test with VariableRateJump
75+
vrate(u, p, t) = t * p[1] * u[1]
76+
vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
77+
vrj = VariableRateJump(vrate, vaffect!)
78+
79+
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj))
80+
@test jp_pure_vrj.aggregator isa PureLeaping
81+
@test jp_pure_vrj.discrete_jump_aggregation === nothing
82+
@test length(jp_pure_vrj.variable_jumps) == 1
83+
84+
# Test with RegularJump
85+
function rj_rate(out, u, p, t)
86+
out[1] = p[1] * u[1]
87+
end
88+
89+
function rj_c(dc, u, p, t, mark)
90+
dc[1, 1] = -1
91+
dc[3, 1] = 1
92+
end
93+
94+
rj_dc = zeros(3, 1)
95+
regj = RegularJump(rj_rate, rj_c, rj_dc; constant_c = true)
96+
97+
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj))
98+
@test jp_pure_regj.aggregator isa PureLeaping
99+
@test jp_pure_regj.discrete_jump_aggregation === nothing
100+
@test jp_pure_regj.regular_jump !== nothing
101+
102+
# Test mixed jump types
103+
mixed_jumps = JumpSet(maj, crj, vrj, regj)
104+
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps)
105+
@test jp_pure_mixed.aggregator isa PureLeaping
106+
@test jp_pure_mixed.discrete_jump_aggregation === nothing
107+
@test jp_pure_mixed.massaction_jump !== nothing
108+
@test length(jp_pure_mixed.constant_jumps) == 1
109+
@test length(jp_pure_mixed.variable_jumps) == 1
110+
@test jp_pure_mixed.regular_jump !== nothing
111+
112+
# Test spatial system error
113+
spatial_sys = CartesianGrid((2, 2))
114+
hopping_consts = [1.0]
115+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
116+
spatial_system = spatial_sys)
117+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
118+
hopping_constants = hopping_consts)
119+
120+
# Test MassActionJump with parameter mapping
121+
param_mapper = MassActionJumpParamMapper([1, 2])
122+
maj_params = MassActionJump(reactant_stoich, net_stoich, param_mapper)
123+
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params))
124+
@test jp_params.massaction_jump.scaled_rates == p
125+
end

0 commit comments

Comments
 (0)