diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 0794d48f2..12cd7052a 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -50,6 +50,7 @@ include("aggregators/prioritytable.jl") include("aggregators/directcr.jl") include("aggregators/rssacr.jl") include("aggregators/rdirect.jl") +include("aggregators/extrande.jl") include("aggregators/coevolve.jl") # spatial: @@ -84,6 +85,7 @@ export Direct, DirectFW, SortingDirect, DirectCR export BracketData, RSSA export FRM, FRMFW, NRM export RSSACR, RDirect +export Extrande export Coevolve export get_num_majumps, needs_depgraph, needs_vartojumps_map diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index 0ec2e28ae..ed046e57d 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -157,8 +157,18 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +""" +The Extrande method for simulating variable rate jumps with user-defined bounds +on jumps rates and validity intervals via rejection. + +Stochastic Simulation of Biomolecular Networks in Dynamic Environments, Voliotis +M, Thomas P, Grima R, Bowsher CG, PLOS Computational Biology 12(6): e1004923. +(2016); doi.org/10.1371/journal.pcbi.1004923 +""" +struct Extrande <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), - FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) + FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), Extrande()) # For JumpProblem construction without an aggregator struct NullAggregator <: AbstractAggregatorAlgorithm end @@ -181,6 +191,7 @@ needs_vartojumps_map(aggregator::RSSACR) = true # true if aggregator supports variable rates supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false supports_variablerates(aggregator::Coevolve) = true +supports_variablerates(aggregator::Extrande) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl new file mode 100644 index 000000000..c59155403 --- /dev/null +++ b/src/aggregators/extrande.jl @@ -0,0 +1,122 @@ +# Define the aggregator. +struct Extrande <: AbstractAggregatorAlgorithm end + +""" +Extrande sampling method for jumps with defined rate bounds. +""" + +nullaffect!(integrator) = nothing +const NullAffectJump = ConstantRateJump((u, p, t) -> 0.0, nullaffect!) + +mutable struct ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG} <: + AbstractSSAJumpAggregator + next_jump::Int + prev_jump::Int + next_jump_time::T + end_time::T + cur_rates::Vector{T} + sum_rate::T + ma_jumps::S + rate_bnds::F3 + wds::F4 + rates::F1 + affects!::F2 + save_positions::Tuple{Bool, Bool} + rng::RNG +end + +function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; + rate_bounds::F3, windows::F4, + kwargs...) where {T, S, F1, F2, F3, F4, RNG} + ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG}(nj, nj, njt, et, crs, sr, maj, + rate_bounds, windows, rs, affs!, sps, + rng) +end + +############################# Required Functions ############################## +function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; variable_jumps = (), kwargs...) + rates, affects! = get_jump_info_fwrappers(u, p, t, + (constant_jumps..., variable_jumps..., + NullAffectJump)) + rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, + (constant_jumps..., variable_jumps..., + NullAffectJump)) + build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds, + windows = wnds, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::ExtrandeJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + generate_jumps!(p, integrator, u, params, t) +end + +# execute one jump, changing the system state +@inline function execute_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) + # execute jump + u = update_state!(p, integrator, u) + nothing +end + +@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) + ttnj = typemax(typeof(t)) + Wmin = typemax(typeof(t)) + Bmax = zero(t) + + prev_rate = zero(t) + new_rate = zero(t) + cur_rates = p.cur_rates + + # Mass action rates + majumps = p.ma_jumps + idx = get_num_majumps(majumps) + + @inbounds for i in 1:idx + new_rate = evalrxrate(u, i, majumps) + cur_rates[i] = add_fast(new_rate, prev_rate) + prev_rate = cur_rates[i] + Bmax += prev_rate + end + + # Calculate the total rate bound and the largest common validity window. + if !isempty(p.rate_bnds) + @inbounds for i in 1:length(p.wds) + Wmin = min(Wmin, p.wds[i](u, params, t)) + Bmax += p.rate_bnds[i](u, params, t) + end + end + + # Rejection sampling. + nextrx = length(cur_rates) + prop_ttnj = randexp(p.rng) / Bmax + if prop_ttnj < Wmin + if !isempty(p.rates) + idx += 1 + fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...) + @inbounds for i in idx:length(cur_rates) + cur_rates[i] = add_fast(cur_rates[i], prev_rate) + prev_rate = cur_rates[i] + end + end + UBmax = rand(p.rng) * Bmax + ttnj = prop_ttnj + if p.cur_rates[end] ≥ UBmax + nextrx = searchsortedfirst(p.cur_rates, UBmax) + end + else + ttnj = Wmin + end + + return nextrx, ttnj +end + +function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) + nextexj, ttnexj = next_extrande_jump(p, u, params, t) + p.next_jump = nextexj + p.next_jump_time = t + ttnexj + + nothing +end diff --git a/src/jumps.jl b/src/jumps.jl index c51790f5c..b8b2ec825 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -702,3 +702,27 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps) rates, affects! end + +##### helpers for splitting variable rate jumps with rate bounds and without ##### + +function rate_window_function(jump) + # Assumes that if no window is given the rate bound is valid for all times. + return !(jump.rateinterval isa Nothing) ? jump.rateinterval : (u, p, t) -> Inf +end + +function get_va_jump_bound_info_fwrapper(u, p, t, jumps) + RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), + Tuple{typeof(u), typeof(p), typeof(t)}} + + if (jumps !== nothing) && !isempty(jumps) + rates = [j isa VariableRateJump ? RateWrapper(j.urate) : RateWrapper(j.rate) + for j in jumps] + wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) : + RateWrapper((u, p, t) -> Inf) for j in jumps] + else + rates = Vector{RateWrapper}() + wnds = Vector{RateWrapper}() + end + + rates, wnds +end diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index b81bb2b34..79e9fb9cf 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -13,7 +13,7 @@ doprint = false doplot = false methods = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), - NRM(), RSSA(), DirectCR(), Coevolve()) + NRM(), RSSA(), DirectCR(), Coevolve(), Extrande()) # one reaction case, mass action jump, vector of data rate = [2.0] diff --git a/test/extrande.jl b/test/extrande.jl new file mode 100644 index 000000000..444ed1ed1 --- /dev/null +++ b/test/extrande.jl @@ -0,0 +1,74 @@ +using DiffEqBase, JumpProcesses, OrdinaryDiffEq, Test +using StableRNGs +using Statistics +rng = StableRNG(48572) + +f = function (du, u, p, t) + du[1] = 0.0 +end + +rate = (u, p, t) -> t < 5.0 ? 1.0 : 0.0 +rbound = (u, p, t) -> 1.0 +rinterval = (u, p, t) -> Inf +affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) +jump = VariableRateJump(rate, affect!; urate = rbound, rateinterval = rinterval) + +prob = ODEProblem(f, [0.0], (0.0, 10.0)) +jump_prob = JumpProblem(prob, Extrande(), jump; rng = rng) + +# Test that process doesn't jump when rate switches to 0. +sol = solve(jump_prob, Tsit5()) +@test sol(5.0)[1] == sol[end][1] + +# Birth-death process with time-varying birth rates. +Nsims = 1000000 +u0 = [10.0] + +function runsimulations(jump_prob, testts) + Psamp = zeros(Int, length(testts), Nsims) + for i in 1:Nsims + sol_ = solve(jump_prob, Tsit5()) + Psamp[:, i] = getindex.(sol_(testts).u, 1) + end + mean(Psamp, dims = 2) +end + +# Variable rate birth jumps. +rateb = (u, p, t) -> (0.1 * sin(t) + 0.2) +ratebbound = (u, p, t) -> 0.3 +ratebwindow = (u, p, t) -> Inf +affectb! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) +jumpb = VariableRateJump(rateb, affectb!; urate = ratebbound, rateinterval = ratebwindow) + +# Constant rate death jumps. +rated = (u, p, t) -> u[1] * 0.08 +affectd! = (integrator) -> (integrator.u[1] = integrator.u[1] - 1) +jumpd = ConstantRateJump(rated, affectd!) + +# Problem definition. +bd_prob = ODEProblem(f, u0, (0.0, 2pi)) +jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd) + +test_times = range(1.0, stop = 2pi, length = 3) +means = runsimulations(jump_bd_prob, test_times) + +# ODE for the mean. +fu = function (du, u, p, t) + du[1] = (0.1 * sin(t) + 0.2) - (u[1] * 0.08) +end + +ode_prob = ODEProblem(fu, u0, (0.0, 2 * pi)) +ode_sol = solve(ode_prob, Tsit5()) + +# Test extrande against the ODE mean. +@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3)) + +# Make sure interfaces correctly with Mass Action Jumps. +reactant_stoich = [[1 => 1]] +net_stoich = [[1 => -1]] +majd = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1]) +bmajd_prob = ODEProblem(f, u0, (0.0, 2pi), [0.08]) +jump_bmajd_prob = JumpProblem(bmajd_prob, Extrande(), jumpb, majd) + +means_mass_action = runsimulations(jump_bmajd_prob, test_times) +@test prod(isapprox.(means_mass_action, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3)) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index 0de428e36..b2801705f 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -105,7 +105,7 @@ h = [Float64[]] Eλ, Varλ = expected_stats_hawkes_problem(p, tspan) -algs = (Direct(), Coevolve(), Coevolve()) +algs = (Direct(), Coevolve(), Coevolve(), Extrande()) uselrate = zeros(Bool, length(algs)) uselrate[3] = true Nsims = 250 @@ -122,7 +122,7 @@ for (i, alg) in enumerate(algs) reset_history!(h) sols[n] = solve(jump_prob, stepper) end - if typeof(alg) <: Coevolve + if typeof(alg) <: Union{Coevolve, Extrande} λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) else cols = length(sols[1].u[1].u) diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index d169b5713..4bc6c5b63 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -16,7 +16,7 @@ tf = 0.1 baserate = 0.1 A0 = 100 exactmean = (t, ratevec) -> A0 * exp(-sum(ratevec) * t) -SSAalgs = [RSSACR(), Direct(), RSSA()] +SSAalgs = [RSSACR(), Direct(), RSSA(), Extrande()] spec_to_dep_jumps = [collect(1:Nrxs), []] jump_to_dep_specs = [[1, 2] for i in 1:Nrxs] diff --git a/test/runtests.jl b/test/runtests.jl index 1b9a7b210..d6587e271 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,4 +28,5 @@ using JumpProcesses, DiffEqBase, SafeTestsets @time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end @time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end @time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end + @time @safetestset "Ficticious Jump " begin include("extrande.jl") end end