Skip to content

Commit c4c3737

Browse files
Merge pull request #438 from isaacsas/float-int-fix
fix #436
2 parents 467e82d + dff61a5 commit c4c3737

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function assemble_crj(js, crj, statetoid)
9191
end
9292

9393
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
94-
statetoid, subber) where {U,V,W,V2,W2}
94+
statetoid, subber, invttype) where {U,V,W,V2,W2}
9595
sr = maj.scaled_rates
9696
if sr isa Operation
9797
pval = subber(sr).value
@@ -118,7 +118,8 @@ function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2
118118
end
119119
sort!(ns)
120120

121-
MassActionJump(pval, rs, ns, scale_rates = false)
121+
maj = MassActionJump(convert(invttype, pval), rs, ns, scale_rates = false)
122+
return maj
122123
end
123124

124125
"""
@@ -163,12 +164,16 @@ sol = solve(jprob, SSAStepper())
163164
"""
164165
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
165166

166-
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
167-
parammap = map((x,y)->Pair(x(),y), parameters(js), prob.p)
167+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
168168
eqs = equations(js)
169-
170-
subber = substituter(first.(parammap), last.(parammap))
171-
majs = MassActionJump[assemble_maj(js, j, statetoid, subber) for j in eqs.x[1]]
169+
invttype = typeof(1 / prob.tspan[2])
170+
171+
# handling parameter substition and empty param vecs
172+
p = (prob.p == DiffEqBase.NullParameters()) ? Operation[] : prob.p
173+
parammap = map((x,y)->Pair(x(),y), parameters(js), p)
174+
subber = substituter(first.(parammap), last.(parammap))
175+
176+
majs = MassActionJump[assemble_maj(js, j, statetoid, subber, invttype) for j in eqs.x[1]]
172177
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
173178
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
174179
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")

src/systems/reaction/reactionsystem.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,10 @@ end
138138
function ReactionSystem(eqs, iv, species, params; systems = ReactionSystem[],
139139
name = gensym(:ReactionSystem))
140140

141-
ReactionSystem(eqs, iv, convert.(Variable,species), convert.(Variable,params),
142-
name, systems)
141+
142+
isempty(species) && error("ReactionSystems require at least one species.")
143+
paramvars = isempty(params) ? Variable[] : convert.(Variable, params)
144+
ReactionSystem(eqs, iv, convert.(Variable,species), paramvars, name, systems)
143145
end
144146

145147
# Calculate the ODE rate law
@@ -233,14 +235,19 @@ explicitly on the independent variable (usually time).
233235
"""
234236
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
235237
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars))
236-
return !(haveivdep || rx.only_use_rate || any(convert(Variable,rxv) in states(rs) for rxv in rxvars))
238+
# if no dependencies must be zero order
239+
if isempty(rxvars)
240+
return true
241+
else
242+
return !(haveivdep || rx.only_use_rate || any(convert(Variable,rxv) in states(rs) for rxv in rxvars))
243+
end
237244
end
238245

239246
function assemble_jumps(rs)
240247
eqs = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}()
241248

242-
for rx in equations(rs)
243-
rxvars = get_variables(rx.rate)
249+
for rx in equations(rs)
250+
rxvars = (rx.rate isa Operation) ? get_variables(rx.rate) : Operation[]
244251
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars)
245252
if ismassaction(rx, rs; rxvars=rxvars, haveivdep=haveivdep)
246253
reactant_stoch = isempty(rx.substoich) ? [0 => 1] : [var2op(sub.op) => stoich for (sub,stoich) in zip(rx.substrates,rx.substoich)]

test/reactionsystem.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ js = convert(JumpSystem, rs)
8888
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.ConstantRateJump, 15:18))
8989
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.VariableRateJump, 19:20))
9090

91-
pars = rand(length(k)); u0 = rand(1:100,4); time = rand();
91+
pars = rand(length(k)); u0 = rand(1:10,4); time = rand();
9292
jumps = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}(undef,length(js.eqs))
9393

9494
jumps[1] = MassActionJump(pars[1], Vector{Pair{Int,Int}}(), [1 => 1]);
@@ -117,23 +117,36 @@ jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], inte
117117
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
118118
parammap = map((x,y)->Pair(x(),y),parameters(js),pars)
119119
for i = 1:14
120-
maj = MT.assemble_maj(js, js.eqs[i], statetoid, ModelingToolkit.substituter(first.(parammap), last.(parammap)))
120+
maj = MT.assemble_maj(js, js.eqs[i], statetoid, ModelingToolkit.substituter(first.(parammap), last.(parammap)),eltype(pars))
121121
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
122122
@test jumps[i].reactant_stoch == maj.reactant_stoch
123123
@test jumps[i].net_stoch == maj.net_stoch
124124
end
125125
for i = 15:18
126126
(i==16) && continue
127127
crj = MT.assemble_crj(js, js.eqs[i], statetoid)
128-
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 100*eps()
128+
@test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time))
129129
fake_integrator1 = (u=zeros(4),p=p,t=0); fake_integrator2 = deepcopy(fake_integrator1);
130130
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
131131
@test fake_integrator1 == fake_integrator2
132132
end
133133
for i = 19:20
134134
crj = MT.assemble_vrj(js, js.eqs[i], statetoid)
135-
@test abs(crj.rate(u0,p,time) - jumps[i].rate(u0,p,time)) < 100*eps()
135+
@test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time))
136136
fake_integrator1 = (u=zeros(4),p=p,t=0.); fake_integrator2 = deepcopy(fake_integrator1);
137137
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
138138
@test fake_integrator1 == fake_integrator2
139139
end
140+
141+
142+
# test for https://github.com/SciML/ModelingToolkit.jl/issues/436
143+
@parameters t
144+
@variables S I
145+
rxs = [Reaction(1,[S],[I]), Reaction(1.1,[S],[I])]
146+
rs = ReactionSystem(rxs, t, [S,I], [])
147+
js = convert(JumpSystem, rs)
148+
dprob = DiscreteProblem(js, [S => 1, I => 1], (0.0,10.0))
149+
jprob = JumpProblem(js, dprob, Direct())
150+
sol = solve(jprob, SSAStepper())
151+
152+
nothing

0 commit comments

Comments
 (0)