Skip to content

Commit 22bbafc

Browse files
committed
jumpsystem fixes
1 parent 8006a54 commit 22bbafc

File tree

4 files changed

+36
-45
lines changed

4 files changed

+36
-45
lines changed

src/systems/dependency_graphs.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ equation_dependencies(jumpsys, variables=parameters(jumpsys))
4343
"""
4444
function equation_dependencies(sys::AbstractSystem; variables=states(sys))
4545
eqs = equations(sys)
46-
deps = Set{Operation}()
47-
depeqs_to_vars = Vector{Vector{Variable}}(undef,length(eqs))
46+
deps = Set()
47+
depeqs_to_vars = Vector{Vector}(undef,length(eqs))
4848

49-
for (i,eq) in enumerate(eqs)
49+
for (i,eq) in enumerate(eqs)
5050
get_variables!(deps, eq, variables)
51-
depeqs_to_vars[i] = [convert(Variable,v) for v in deps]
51+
depeqs_to_vars[i] = [value(v) for v in deps]
5252
empty!(deps)
5353
end
5454

@@ -165,7 +165,7 @@ digr = asgraph(odesys)
165165
```
166166
"""
167167
function asgraph(sys::AbstractSystem; variables=states(sys),
168-
variablestoids=Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)))
168+
variablestoids=Dict(v => i for (i,v) in enumerate(variables)))
169169
asgraph(equation_dependencies(sys, variables=variables), variablestoids)
170170
end
171171

@@ -190,13 +190,13 @@ variable_dependencies(odesys)
190190
"""
191191
function variable_dependencies(sys::AbstractSystem; variables=states(sys), variablestoids=nothing)
192192
eqs = equations(sys)
193-
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)) : variablestoids
193+
vtois = isnothing(variablestoids) ? Dict(v => i for (i,v) in enumerate(variables)) : variablestoids
194194

195-
deps = Set{Operation}()
195+
deps = Set()
196196
badjlist = Vector{Vector{Int}}(undef, length(eqs))
197197
for (eidx,eq) in enumerate(eqs)
198198
modified_states!(deps, eq, variables)
199-
badjlist[eidx] = sort!([vtois[convert(Variable,var)] for var in deps])
199+
badjlist[eidx] = sort!([vtois[var] for var in deps])
200200
empty!(deps)
201201
end
202202

src/systems/jumps/jumpsystem.jl

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ generate_affect_function(js, affect, outputidxs) = build_function(affect, states
8080

8181
function assemble_vrj(js, vrj, statetoid)
8282
rate = eval(generate_rate_function(js, vrj.rate))
83-
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
83+
outputvars = (value(affect.lhs) for affect in vrj.affect!)
8484
outputidxs = ((statetoid[var] for var in outputvars)...,)
8585
affect = eval(generate_affect_function(js, vrj.affect!, outputidxs))
8686
VariableRateJump(rate, affect)
8787
end
8888

8989
function assemble_vrj_expr(js, vrj, statetoid)
9090
rate = generate_rate_function(js, vrj.rate)
91-
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
91+
outputvars = (value(affect.lhs) for affect in vrj.affect!)
9292
outputidxs = ((statetoid[var] for var in outputvars)...,)
9393
affect = generate_affect_function(js, vrj.affect!, outputidxs)
9494
quote
@@ -100,15 +100,15 @@ end
100100

101101
function assemble_crj(js, crj, statetoid)
102102
rate = eval(generate_rate_function(js, crj.rate))
103-
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!)
103+
outputvars = (value(affect.lhs) for affect in crj.affect!)
104104
outputidxs = ((statetoid[var] for var in outputvars)...,)
105105
affect = eval(generate_affect_function(js, crj.affect!, outputidxs))
106106
ConstantRateJump(rate, affect)
107107
end
108108

109109
function assemble_crj_expr(js, crj, statetoid)
110110
rate = generate_rate_function(js, crj.rate)
111-
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!)
111+
outputvars = (value(affect.lhs) for affect in crj.affect!)
112112
outputidxs = ((statetoid[var] for var in outputvars)...,)
113113
affect = generate_affect_function(js, crj.affect!, outputidxs)
114114
quote
@@ -118,24 +118,13 @@ function assemble_crj_expr(js, crj, statetoid)
118118
end
119119
end
120120

121-
function numericrate(rate, subber)
122-
if rate isa Operation
123-
rval = subber(rate).value
124-
elseif rate isa Variable
125-
rval = subber(rate()).value
126-
else
127-
rval = rate
128-
end
129-
rval
130-
end
131-
132121
function numericrstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
133122
rs = Vector{Pair{Int,W}}()
134123
for (spec,stoich) in mtrs
135-
if !(spec isa Operation) && iszero(spec)
124+
if !(spec isa Term) && iszero(spec)
136125
push!(rs, 0 => stoich)
137126
else
138-
push!(rs, statetoid[convert(Variable,spec)] => stoich)
127+
push!(rs, statetoid[value(spec)] => stoich)
139128
end
140129
end
141130
sort!(rs)
@@ -145,18 +134,19 @@ end
145134
function numericnstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
146135
ns = Vector{Pair{Int,W}}()
147136
for (spec,stoich) in mtrs
148-
!(spec isa Operation) && iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
149-
push!(ns, statetoid[convert(Variable,spec)] => stoich)
137+
!(spec isa Term) && iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
138+
push!(ns, statetoid[value(spec)] => stoich)
150139
end
151140
sort!(ns)
152141
end
153142

154143
# assemble a numeric MassActionJump from a MT MassActionJump representing one rx.
155144
function assemble_maj(maj::MassActionJump, statetoid, subber, invttype)
156-
rval = numericrate(maj.scaled_rates, subber)
145+
rval = subber(maj.scaled_rates)
157146
rs = numericrstoich(maj.reactant_stoch, statetoid)
158147
ns = numericnstoich(maj.net_stoch, statetoid)
159-
maj = MassActionJump(convert(invttype, rval), rs, ns, scale_rates = false)
148+
@show rval
149+
maj = MassActionJump(convert(invttype, value(rval)), rs, ns, scale_rates = false)
160150
maj
161151
end
162152

@@ -192,11 +182,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Tuple,
192182
parammap=DiffEqBase.NullParameters(); kwargs...)
193183

194184
(u0map isa AbstractVector) || error("For DiscreteProblems u0map must be an AbstractVector.")
195-
u0d = Dict( convert(Variable,u[1]) => u[2] for u in u0map)
185+
u0d = Dict( value(u[1]) => u[2] for u in u0map)
196186
u0 = [u0d[u] for u in states(sys)]
197187
if parammap != DiffEqBase.NullParameters()
198188
(parammap isa AbstractVector) || error("For DiscreteProblems parammap must be an AbstractVector.")
199-
pd = Dict( convert(Variable,u[1]) => u[2] for u in parammap)
189+
pd = Dict( value(u[1]) => u[2] for u in parammap)
200190
p = [pd[u] for u in parameters(sys)]
201191
else
202192
p = parammap
@@ -257,15 +247,16 @@ sol = solve(jprob, SSAStepper())
257247
"""
258248
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
259249

260-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
250+
statetoid = Dict(value(state) => i for (i,state) in enumerate(states(js)))
261251
eqs = equations(js)
262252
invttype = typeof(1 / prob.tspan[2])
263253

264254
# handling parameter substition and empty param vecs
265255
p = (prob.p == DiffEqBase.NullParameters()) ? Operation[] : prob.p
266-
parammap = map((x,y)->Pair(x(),y), parameters(js), p)
256+
parammap = map((x,y)->Pair(x,y), parameters(js), p)
267257
subber = substituter(parammap)
268258

259+
@show parammap
269260
majs = MassActionJump[assemble_maj(j, statetoid, subber, invttype) for j in eqs.x[1]]
270261
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
271262
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
@@ -294,9 +285,9 @@ end
294285

295286
function get_variables!(dep, jump::MassActionJump, variables)
296287
sr = jump.scaled_rates
297-
(sr isa Operation) && get_variables!(dep, sr, variables)
288+
(sr isa Term) && get_variables!(dep, sr, variables)
298289
for varasop in jump.reactant_stoch
299-
(varasop[1].op in variables) && push!(dep, varasop[1])
290+
any(isequal(varasop[1]), variables) && push!(dep, varasop[1])
300291
end
301292
dep
302293
end
@@ -305,12 +296,12 @@ end
305296
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
306297
for eq in jump.affect!
307298
st = eq.lhs
308-
(st.op in sts) && push!(mstates, st)
299+
any(isequal(st), sts) && push!(mstates, st)
309300
end
310301
end
311302

312303
function modified_states!(mstates, jump::MassActionJump, sts)
313304
for (state,stoich) in jump.net_stoch
314-
(state.op in sts) && push!(mstates, state)
305+
any(isequal(state), sts) && push!(mstates, state)
315306
end
316307
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ substitute(expr::Term, s::Vector; kw...) = substitute(expr, Dict(s); kw...)
9595

9696
function substituter(pairs)
9797
dict = Dict(to_symbolic(k) => to_symbolic(v) for (k, v) in pairs)
98-
expr -> to_mtk(SymbolicUtils.substitute(expr, dict))
98+
expr -> SymbolicUtils.substitute(expr, dict)
9999
end
100100

101101
macro showarr(x)

test/jumpsystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ affect₂ = [I ~ I - 1, R ~ R + 1]
1111
j₁ = ConstantRateJump(rate₁,affect₁)
1212
j₂ = VariableRateJump(rate₂,affect₂)
1313
js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ])
14-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
14+
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
1515
mtjump1 = MT.assemble_crj(js, j₁, statetoid)
1616
mtjump2 = MT.assemble_vrj(js, j₂, statetoid)
1717

@@ -116,8 +116,8 @@ m2 = getmean(jprob,Nsims)
116116
maj1 = MassActionJump(2*β/2, [S => 1, I => 1], [S => -1, I => 1])
117117
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
118118
js3 = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
119-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
120-
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
119+
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
120+
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
121121
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
122122
jprob = JumpProblem(js3, dprob, Direct())
123123
m3 = getmean(jprob,Nsims)
@@ -136,8 +136,8 @@ m4 = getmean(jprobc,Nsims)
136136
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
137137
maj2 = MassActionJump(γ, [S => 1], [S => -1])
138138
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
139-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
140-
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
139+
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
140+
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
141141
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
142142
jprob = JumpProblem(js4, dprob, Direct())
143143
m4 = getmean(jprob,Nsims)
@@ -147,8 +147,8 @@ m4 = getmean(jprob,Nsims)
147147
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
148148
maj2 = MassActionJump(γ, [S => 2], [S => -1])
149149
js4 = JumpSystem([maj1,maj2], t, [S], [β,γ])
150-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
151-
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
150+
statetoid = Dict(MT.value(state) => i for (i,state) in enumerate(states(js)))
151+
ptoid = Dict(MT.value(par) => i for (i,par) in enumerate(parameters(js)))
152152
dprob = DiscreteProblem(js4, [S => 999], (0,1000.), [β => 100.=> .01])
153153
jprob = JumpProblem(js4, dprob, Direct())
154154
sol = solve(jprob, SSAStepper());

0 commit comments

Comments
 (0)