Skip to content

Commit 471ddc1

Browse files
Merge pull request #1143 from isaacsas/pmapper-for-jumpsystems
switch MassActionJumps to use parameter vectors
2 parents c294a7d + 14d6f0a commit 471ddc1

File tree

4 files changed

+87
-35
lines changed

4 files changed

+87
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ ConstructionBase = "1"
5050
DataStructures = "0.17, 0.18"
5151
DiffEqBase = "6.54.0"
5252
DiffEqCallbacks = "2.16"
53-
DiffEqJump = "6.7.5"
53+
DiffEqJump = "7.0"
5454
DiffRules = "0.1, 1.0"
5555
Distributions = "0.23, 0.24, 0.25"
5656
DocStringExtensions = "0.7, 0.8"

src/systems/jumps/jumpsystem.jl

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,13 @@ function numericnstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
185185
sort!(ns)
186186
end
187187

188-
# assemble a numeric MassActionJump from a MT MassActionJump representing one rx.
189-
function assemble_maj(maj::MassActionJump, statetoid, subber, invttype)
190-
rval = subber(maj.scaled_rates)
191-
rs = numericrstoich(maj.reactant_stoch, statetoid)
192-
ns = numericnstoich(maj.net_stoch, statetoid)
193-
maj = MassActionJump(convert(invttype, value(rval)), rs, ns, scale_rates = false)
194-
maj
188+
# assemble a numeric MassActionJump from a MT symbolics MassActionJumps
189+
function assemble_maj(majv::Vector{U}, statetoid, pmapper) where {U <: MassActionJump}
190+
rs = [numericrstoich(maj.reactant_stoch, statetoid) for maj in majv]
191+
ns = [numericnstoich(maj.net_stoch, statetoid) for maj in majv]
192+
MassActionJump(rs, ns; param_mapper=pmapper, nocopy=true)
195193
end
196194

197-
# For MassActionJumps that contain many reactions
198-
# function assemble_maj(maj::MassActionJump{U,V,W}, statetoid, subber,
199-
# invttype) where {U <: AbstractVector,V,W}
200-
# rval = [convert(invttype,numericrate(sr, subber)) for sr in maj.scaled_rates]
201-
# rs = [numericrstoich(rs, statetoid) for rs in maj.reactant_stoch]
202-
# ns = [numericnstoich(ns, statetoid) for ns in maj.net_stoch]
203-
# maj = MassActionJump(rval, rs, ns, scale_rates = false)
204-
# maj
205-
# end
206195
"""
207196
```julia
208197
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
@@ -287,14 +276,13 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
287276

288277
# handling parameter substition and empty param vecs
289278
p = (prob.p isa DiffEqBase.NullParameters || prob.p === nothing) ? Num[] : prob.p
290-
parammap = map((x,y)->Pair(x,y), parameters(js), p)
291-
subber = substituter(parammap)
292279

293-
majs = MassActionJump[assemble_maj(j, statetoid, subber, invttype) for j in eqs.x[1]]
280+
majpmapper = JumpSysMajParamMapper(js, p; jseqs=eqs, rateconsttype=invttype)
281+
majs = isempty(eqs.x[1]) ? nothing : assemble_maj(eqs.x[1], statetoid, majpmapper)
294282
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
295283
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
296284
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
297-
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
285+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
298286

299287
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
300288
jdeps = asgraph(js)
@@ -306,7 +294,8 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
306294
vtoj = nothing; jtov = nothing; jtoj = nothing
307295
end
308296

309-
JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov, kwargs...)
297+
JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov,
298+
scale_rates=false, nocopy=true, kwargs...)
310299
end
311300

312301

@@ -338,3 +327,49 @@ function modified_states!(mstates, jump::MassActionJump, sts)
338327
any(isequal(state), sts) && push!(mstates, state)
339328
end
340329
end
330+
331+
332+
333+
###################### parameter mapper ###########################
334+
struct JumpSysMajParamMapper{U,V,W}
335+
paramexprs::U # the parameter expressions to use for each jump rate constant
336+
sympars::V # parameters(sys) from the underlying JumpSystem
337+
subdict # mapping from an element of parameters(sys) to its current numerical value
338+
end
339+
340+
function JumpSysMajParamMapper(js::JumpSystem, p; jseqs=nothing, rateconsttype=Float64)
341+
eqs = (jseqs === nothing) ? equations(js) : jseqs
342+
paramexprs = [maj.scaled_rates for maj in eqs.x[1]]
343+
psyms = parameters(js)
344+
paramdict = Dict(value(k) => value(v) for (k, v) in zip(psyms,p))
345+
JumpSysMajParamMapper{typeof(paramexprs),typeof(psyms),rateconsttype}(paramexprs, psyms, paramdict)
346+
end
347+
348+
function updateparams!(ratemap::JumpSysMajParamMapper{U,V,W}, params) where {U <: AbstractArray, V <: AbstractArray, W}
349+
for (i,p) in enumerate(params)
350+
sympar = ratemap.sympars[i]
351+
ratemap.subdict[sympar] = p
352+
end
353+
nothing
354+
end
355+
356+
function updateparams!(::JumpSysMajParamMapper{U,V,W}, params::Nothing) where {U <: AbstractArray, V <: AbstractArray, W}
357+
nothing
358+
end
359+
360+
361+
# create the initial parameter vector for use in a MassActionJump
362+
function (ratemap::JumpSysMajParamMapper{U,V,W})(params) where {U <: AbstractArray, V <: AbstractArray, W}
363+
updateparams!(ratemap, params)
364+
[convert(W,value(substitute(paramexpr, ratemap.subdict))) for paramexpr in ratemap.paramexprs]
365+
end
366+
367+
# update a maj with parameter vectors
368+
function (ratemap::JumpSysMajParamMapper{U,V,W})(maj::MassActionJump, newparams; scale_rates, kwargs...) where {U <: AbstractArray, V <: AbstractArray, W}
369+
updateparams!(ratemap, newparams)
370+
for i in 1:get_num_majumps(maj)
371+
maj.scaled_rates[i] = convert(W,value(substitute(ratemap.paramexprs[i], ratemap.subdict)))
372+
end
373+
scale_rates && DiffEqJump.scalerates!(maj.scaled_rates, maj.reactant_stoch)
374+
nothing
375+
end

test/jumpsystem.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ m2 = getmean(jprob,Nsims)
116116
maj1 = MassActionJump(2*β/2, [S => 1, I => 1], [S => -1, I => 1])
117117
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
118118
js3 = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
119-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
120-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
121119
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
122120
jprob = JumpProblem(js3, dprob, Direct())
123121
m3 = getmean(jprob,Nsims)
@@ -136,8 +134,6 @@ m4 = getmean(jprobc,Nsims)
136134
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
137135
maj2 = MassActionJump(γ, [S => 1], [S => -1])
138136
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
139-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
140-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
141137
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
142138
jprob = JumpProblem(js4, dprob, Direct())
143139
m4 = getmean(jprob,Nsims)
@@ -147,8 +143,6 @@ m4 = getmean(jprob,Nsims)
147143
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
148144
maj2 = MassActionJump(γ, [S => 2], [S => -1])
149145
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
150-
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
151-
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
152146
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
153147
jprob = JumpProblem(js4, dprob, Direct())
154148
sol = solve(jprob, SSAStepper());
@@ -158,4 +152,26 @@ sol = solve(jprob, SSAStepper());
158152
sys1 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
159153
sys2 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
160154
@test_throws ArgumentError JumpSystem([sys1.γ ~ sys2.γ], t, [], [], systems = [sys1, sys2])
161-
end
155+
end
156+
157+
# test if param mapper is setup correctly for callbacks
158+
@parameters k1 k2 k3
159+
@variables A(t) B(t)
160+
maj1 = MassActionJump(k1*k3, [0 => 1], [A => -1, B => 1])
161+
maj2 = MassActionJump(k2, [B => 1], [A => 1, B => -1])
162+
js5 = JumpSystem([maj1,maj2], t, [A,B], [k1,k2,k3])
163+
p = [k1 => 2.0, k2 => 0.0, k3 => .5]
164+
u₀ = [A => 100, B => 0]
165+
tspan = (0.0,2000.0)
166+
dprob = DiscreteProblem(js5, u₀, tspan, p)
167+
jprob = JumpProblem(js5, dprob, Direct(), save_positions=(false,false))
168+
@test all(jprob.massaction_jump.scaled_rates .== [1.0,0.0])
169+
170+
pcondit(u,t,integrator) = t==1000.0
171+
function paffect!(integrator)
172+
integrator.p[1] = 0.0
173+
integrator.p[2] = 1.0
174+
reset_aggregated_jumps!(integrator)
175+
end
176+
sol = solve(jprob, SSAStepper(), tstops=[1000.0], callback=DiscreteCallback(pcondit,paffect!))
177+
@test sol[1,end] == 100

test/reactionsystem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,13 @@ jumps[19] = VariableRateJump((u,p,t) -> p[19]*u[1]*t, integrator -> (integrator.
192192
jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], integrator -> (integrator.u[2] -= 2; integrator.u[3] -= 1; integrator.u[4] += 2))
193193

194194
statetoid = Dict(state => i for (i,state) in enumerate(states(js)))
195-
parammap = map((x,y)->Pair(x,y),parameters(js),pars)
196-
for i in midxs
197-
maj = MT.assemble_maj(equations(js)[i], statetoid, ModelingToolkit.substituter(parammap),eltype(pars))
198-
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
199-
@test jumps[i].reactant_stoch == maj.reactant_stoch
200-
@test jumps[i].net_stoch == maj.net_stoch
195+
jspmapper = ModelingToolkit.JumpSysMajParamMapper(js, pars)
196+
symmaj = MT.assemble_maj(equations(js).x[1], statetoid, jspmapper)
197+
maj = MassActionJump(symmaj.param_mapper(pars), symmaj.reactant_stoch, symmaj.net_stoch, symmaj.param_mapper, scale_rates=false)
198+
for i in midxs
199+
@test abs(jumps[i].scaled_rates - maj.scaled_rates[i]) < 100*eps()
200+
@test jumps[i].reactant_stoch == maj.reactant_stoch[i]
201+
@test jumps[i].net_stoch == maj.net_stoch[i]
201202
end
202203
for i in cidxs
203204
crj = MT.assemble_crj(js, equations(js)[i], statetoid)

0 commit comments

Comments
 (0)