Skip to content

Commit 9f1ebfc

Browse files
Merge pull request #432 from SciML/s/substituter
`substituter` to avoid Dict re-construction
2 parents a32b5aa + 074b70d commit 9f1ebfc

File tree

4 files changed

+25
-36
lines changed

4 files changed

+25
-36
lines changed

src/direct.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ function simplified_expr(O::Operation)
4343
isempty(O.args) && return O.op.name
4444
return Expr(:call, Symbol(O.op), simplified_expr.(O.args)...)
4545
end
46+
if O.op === (^)
47+
if length(O.args) > 1 && O.args[2] isa Constant && O.args[2].value < 0
48+
return Expr(:call, :^, Expr(:call, :inv, simplified_expr(O.args[1])), -(O.args[2].value))
49+
end
50+
end
4651
return Expr(:call, Symbol(O.op), simplified_expr.(O.args)...)
4752
end
4853

src/systems/jumps/jumpsystem.jl

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,13 @@ function assemble_crj(js, crj, statetoid)
9191
end
9292

9393
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
94-
statetoid, parammap, pcontext) where {U,V,W,V2,W2}
95-
94+
statetoid, subber) where {U,V,W,V2,W2}
9695
sr = maj.scaled_rates
97-
if sr isa Operation
98-
if isempty(sr.args)
99-
pval = parammap[sr.op]
100-
else
101-
pval = Base.eval(pcontext, Expr(maj.scaled_rates))
102-
end
96+
if sr isa Operation
97+
pval = subber(sr).value
10398
elseif sr isa Variable
104-
pval = parammap[sr]
105-
else
99+
pval = subber(sr()).value
100+
else
106101
pval = maj.scaled_rates
107102
end
108103

@@ -169,20 +164,11 @@ sol = solve(jprob, SSAStepper())
169164
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
170165

171166
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
172-
parammap = Dict(convert(Variable,param) => prob.p[i] for (i,param) in enumerate(parameters(js)))
167+
parammap = map((x,y)->Pair(x(),y), parameters(js), prob.p)
173168
eqs = equations(js)
174-
175-
# for mass action jumps might need to evaluate parameter expressions
176-
# populate dummy module with params as local variables
177-
# (for eval-ing parameter expressions)
178-
pvars = parameters(js)
179-
param_context = Module()
180-
for (i, pval) in enumerate(prob.p)
181-
psym = Symbol(pvars[i])
182-
Base.eval(param_context, :($psym = $pval))
183-
end
184169

185-
majs = MassActionJump[assemble_maj(js, j, statetoid, parammap, param_context) for j in eqs.x[1]]
170+
subber = substituter(first.(parammap), last.(parammap))
171+
majs = MassActionJump[assemble_maj(js, j, statetoid, subber) for j in eqs.x[1]]
186172
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
187173
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
188174
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
@@ -225,4 +211,4 @@ function modified_states!(mstates, jump::MassActionJump, sts)
225211
for (state,stoich) in jump.net_stoch
226212
(state.op in sts) && push!(mstates, state)
227213
end
228-
end
214+
end

src/utils.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,16 @@ substitute(expr::Operation, s::Pair) = _substitute(expr, [s[1]], [s[2]])
121121
substitute(expr::Operation, dict::Dict) = _substitute(expr, keys(dict), values(dict))
122122
substitute(expr::Operation, s::Vector) = _substitute(expr, first.(s), last.(s))
123123

124-
function _substitute(expr, ks, vs)
125-
_substitute(expr, Dict(map(Pair, map(to_symbolic, ks), map(to_symbolic, vs))))
124+
function _substitute(ks, vs)
125+
expr -> _substitute(expr, Dict(map(Pair, map(to_symbolic, ks), map(to_symbolic, vs))))
126126
end
127127

128-
function _substitute(expr, dict::Dict)
129-
simplify(SymbolicUtils.substitute(expr, dict))
128+
function substituter(ks, vs)
129+
dict = Dict(map(Pair, map(to_symbolic, ks), map(to_symbolic, vs)))
130+
expr -> to_mtk(SymbolicUtils.simplify(SymbolicUtils.substitute(expr, dict)))
130131
end
131132

132-
@deprecate substitute_expr!(expr,s) substitute(expr,s)
133+
_substitute(expr, ks, vs) = substituter(ks, vs)(expr)
134+
135+
@deprecate substitute_expr!(expr,s) substitute(expr,s)
136+

test/reactionsystem.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,9 @@ jumps[19] = VariableRateJump((u,p,t) -> p[19]*u[1]*t, integrator -> (integrator.
115115
jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], integrator -> (integrator.u[2] -= 2; integrator.u[3] -= 1; integrator.u[4] += 2))
116116

117117
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
118-
parammap = Dict(convert(Variable,param) => pars[i] for (i,param) in enumerate(parameters(js)))
119-
pvars = parameters(js)
120-
param_context = Module()
121-
for (i, pval) in enumerate(pars)
122-
psym = Symbol(pvars[i])
123-
Base.eval(param_context, :($psym = $pval))
124-
end
118+
parammap = map((x,y)->Pair(x(),y),parameters(js),pars)
125119
for i = 1:14
126-
maj = MT.assemble_maj(js, js.eqs[i], statetoid, parammap, param_context)
120+
maj = MT.assemble_maj(js, js.eqs[i], statetoid, ModelingToolkit.substituter(first.(parammap), last.(parammap)))
127121
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
128122
@test jumps[i].reactant_stoch == maj.reactant_stoch
129123
@test jumps[i].net_stoch == maj.net_stoch

0 commit comments

Comments
 (0)