Skip to content

Commit 637be25

Browse files
committed
switch MassActionJumps to use parameters
1 parent 2e3e7f2 commit 637be25

File tree

3 files changed

+86
-32
lines changed

3 files changed

+86
-32
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,23 +186,23 @@ function numericnstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
186186
end
187187

188188
# assemble a numeric MassActionJump from a MT MassActionJump representing one rx.
189-
function assemble_maj(maj::MassActionJump, statetoid, subber, invttype)
190-
rval = subber(maj.scaled_rates)
191-
rs = numericrstoich(maj.reactant_stoch, statetoid)
192-
ns = numericnstoich(maj.net_stoch, statetoid)
193-
maj = MassActionJump(convert(invttype, value(rval)), rs, ns, scale_rates = false)
194-
maj
195-
end
196-
197-
# For MassActionJumps that contain many reactions
198-
# function assemble_maj(maj::MassActionJump{U,V,W}, statetoid, subber,
199-
# invttype) where {U <: AbstractVector,V,W}
200-
# rval = [convert(invttype,numericrate(sr, subber)) for sr in maj.scaled_rates]
201-
# rs = [numericrstoich(rs, statetoid) for rs in maj.reactant_stoch]
202-
# ns = [numericnstoich(ns, statetoid) for ns in maj.net_stoch]
203-
# maj = MassActionJump(rval, rs, ns, scale_rates = false)
189+
# function assemble_maj(maj::MassActionJump, statetoid, subber, invttype)
190+
# rval = subber(maj.scaled_rates)
191+
# rs = numericrstoich(maj.reactant_stoch, statetoid)
192+
# ns = numericnstoich(maj.net_stoch, statetoid)
193+
# maj = MassActionJump(convert(invttype, value(rval)), rs, ns, scale_rates = false)
204194
# maj
205195
# end
196+
197+
rstype(::MassActionJump{U,Vector{Pair{V,W}},X,Y}) where {U,V,W,X,Y} = W
198+
199+
# assemble a numeric MassActionJump from a MT symbolics MassActionJumps
200+
function assemble_maj(majv::Vector{U}, statetoid, pmapper, params) where {U <: MassActionJump}
201+
rs = [numericrstoich(maj.reactant_stoch, statetoid) for maj in majv]
202+
ns = [numericnstoich(maj.net_stoch, statetoid) for maj in majv]
203+
MassActionJump(rs, ns; param_mapper = pmapper, params=params, scale_rates=false, nocopy=true)
204+
end
205+
206206
"""
207207
```julia
208208
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
@@ -287,14 +287,13 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
287287

288288
# handling parameter substition and empty param vecs
289289
p = (prob.p isa DiffEqBase.NullParameters || prob.p === nothing) ? Num[] : prob.p
290-
parammap = map((x,y)->Pair(x,y), parameters(js), p)
291-
subber = substituter(parammap)
292290

293-
majs = MassActionJump[assemble_maj(j, statetoid, subber, invttype) for j in eqs.x[1]]
291+
majpmapper = JumpSysMajParamMapper(js, p; jseqs=eqs, rateconsttype=invttype)
292+
majs = isempty(eqs.x[1]) ? nothing : assemble_maj(eqs.x[1], statetoid, majpmapper, p)
294293
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
295294
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
296295
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
297-
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
296+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
298297

299298
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
300299
jdeps = asgraph(js)
@@ -338,3 +337,44 @@ function modified_states!(mstates, jump::MassActionJump, sts)
338337
any(isequal(state), sts) && push!(mstates, state)
339338
end
340339
end
340+
341+
342+
343+
###################### parameter mapper ###########################
344+
struct JumpSysMajParamMapper{U,V,W}
345+
paramexprs::U # the parameter expressions to use for each jump rate constant
346+
sympars::V # parameters(sys) from the underlying JumpSystem
347+
subdict # mapping from an element of parameters(sys) to its current numerical value
348+
end
349+
350+
function JumpSysMajParamMapper(js::JumpSystem, p; jseqs=nothing, rateconsttype=Float64)
351+
eqs = (jseqs === nothing) ? equations(js) : jseqs
352+
psyms = parameters(js)
353+
parammap = map((x,y)->Pair(x,y), psyms, p)
354+
paramdict = Dict(value(k) => value(v) for (k, v) in parammap)
355+
paramexprs = [maj.scaled_rates for maj in eqs.x[1]]
356+
JumpSysMajParamMapper{typeof(paramexprs),typeof(psyms),rateconsttype}(paramexprs, psyms, paramdict)
357+
end
358+
359+
function updateparams!(ratemap::JumpSysMajParamMapper{U,V,W}, params) where {U <: AbstractArray, V <: AbstractArray, W}
360+
for (i,p) in enumerate(params)
361+
sympar = ratemap.sympars[i]
362+
ratemap.subdict[sympar] = p
363+
end
364+
end
365+
366+
# create the initial parameter vector for use in a MassActionJump
367+
function (ratemap::JumpSysMajParamMapper{U,V,W})(params) where {U <: AbstractArray, V <: AbstractArray, W}
368+
updateparams!(ratemap, params)
369+
[convert(W,value(substitute(paramexpr, ratemap.subdict))) for paramexpr in ratemap.paramexprs]
370+
end
371+
372+
# update a maj with parameter vectors
373+
function (ratemap::JumpSysMajParamMapper{U,V,W})(maj::MassActionJump, newparams; scale_rates, kwargs...) where {U <: AbstractArray, V <: AbstractArray, W}
374+
updateparams!(ratemap, newparams)
375+
for i in 1:get_num_majumps(maj)
376+
maj.scaled_rates[i] = convert(W,value(substitute(ratemap.paramexprs[i], ratemap.subdict)))
377+
end
378+
scale_rates && DiffEqJump.scalerates!(maj.scaled_rates, maj.reactant_stoch)
379+
nothing
380+
end

test/jumpsystem.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ m2 = getmean(jprob,Nsims)
116116
maj1 = MassActionJump(2*β/2, [S => 1, I => 1], [S => -1, I => 1])
117117
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
118118
js3 = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
119-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
120-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
121119
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
122120
jprob = JumpProblem(js3, dprob, Direct())
123121
m3 = getmean(jprob,Nsims)
@@ -136,8 +134,6 @@ m4 = getmean(jprobc,Nsims)
136134
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
137135
maj2 = MassActionJump(γ, [S => 1], [S => -1])
138136
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
139-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
140-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
141137
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
142138
jprob = JumpProblem(js4, dprob, Direct())
143139
m4 = getmean(jprob,Nsims)
@@ -147,8 +143,6 @@ m4 = getmean(jprob,Nsims)
147143
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
148144
maj2 = MassActionJump(γ, [S => 2], [S => -1])
149145
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
150-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
151-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
152146
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
153147
jprob = JumpProblem(js4, dprob, Direct())
154148
sol = solve(jprob, SSAStepper());
@@ -158,4 +152,24 @@ sol = solve(jprob, SSAStepper());
158152
sys1 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
159153
sys2 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
160154
@test_throws ArgumentError JumpSystem([sys1.γ ~ sys2.γ], t, [], [], systems = [sys1, sys2])
161-
end
155+
end
156+
157+
# test if param mapper is setup correctly for callbacks
158+
@parameters k1 k2 k3
159+
@variables A(t) B(t)
160+
maj1 = MassActionJump(k1*k3, [0 => 1], [A => -1, B => 1])
161+
maj2 = MassActionJump(k2, [B => 1], [A => 1, B => -1])
162+
js5 = JumpSystem([maj1,maj2], t, [A,B], [k1,k2,k3])
163+
p = [k1 => 2.0, k2 => 0.0, k3 => .5]
164+
u₀ = [A => 100, B => 0]
165+
tspan = (0.0,2000.0)
166+
dprob = DiscreteProblem(js5, u₀, tspan, p)
167+
jprob = JumpProblem(js5, dprob, Direct(), save_positions=(false,false))
168+
pcondit(u,t,integrator) = t==1000.0
169+
function paffect!(integrator)
170+
integrator.p[1] = 0.0
171+
integrator.p[2] = 1.0
172+
reset_aggregated_jumps!(integrator)
173+
end
174+
sol = solve(jprob, SSAStepper(), tstops=[1000.0], callback=DiscreteCallback(pcondit,paffect!))
175+
@test sol[1,end] == 100

test/reactionsystem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,12 @@ jumps[19] = VariableRateJump((u,p,t) -> p[19]*u[1]*t, integrator -> (integrator.
192192
jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], integrator -> (integrator.u[2] -= 2; integrator.u[3] -= 1; integrator.u[4] += 2))
193193

194194
statetoid = Dict(state => i for (i,state) in enumerate(states(js)))
195-
parammap = map((x,y)->Pair(x,y),parameters(js),pars)
196-
for i in midxs
197-
maj = MT.assemble_maj(equations(js)[i], statetoid, ModelingToolkit.substituter(parammap),eltype(pars))
198-
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
199-
@test jumps[i].reactant_stoch == maj.reactant_stoch
200-
@test jumps[i].net_stoch == maj.net_stoch
195+
jspmapper = ModelingToolkit.JumpSysMajParamMapper(js, pars)
196+
maj = MT.assemble_maj(equations(js).x[1], statetoid, jspmapper, pars)
197+
for i in midxs
198+
@test abs(jumps[i].scaled_rates - maj.scaled_rates[i]) < 100*eps()
199+
@test jumps[i].reactant_stoch == maj.reactant_stoch[i]
200+
@test jumps[i].net_stoch == maj.net_stoch[i]
201201
end
202202
for i in cidxs
203203
crj = MT.assemble_crj(js, equations(js)[i], statetoid)

0 commit comments

Comments
 (0)