Skip to content

Commit 3160770

Browse files
committed
update get_variables! for Equations
1 parent db1c1db commit 3160770

File tree

4 files changed

+29
-29
lines changed

4 files changed

+29
-29
lines changed

src/systems/dependency_graphs.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
# each system type should define extract_variables! for a single equation
33
function equation_dependencies(sys::AbstractSystem; variables=states(sys))
44
eqs = equations(sys)
5-
deps = Set{Variable}()
5+
deps = Set{Operation}()
66
depeqs_to_vars = Vector{Vector{Variable}}(undef,length(eqs))
77

88
for (i,eq) in enumerate(eqs)
9-
depeqs_to_vars[i] = collect(get_variables!(deps, eq, variables))
9+
get_variables!(deps, eq, variables)
10+
depeqs_to_vars[i] = [convert(Variable,v) for v in deps]
1011
empty!(deps)
1112
end
1213

@@ -57,11 +58,11 @@ function variable_dependencies(sys::AbstractSystem; variables=states(sys), varia
5758
eqs = equations(sys)
5859
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)) : variablestoids
5960

60-
deps = Set{Variable}()
61+
deps = Set{Operation}()
6162
badjlist = Vector{Vector{Int}}(undef, length(eqs))
6263
for (eidx,eq) in enumerate(eqs)
6364
modified_states!(deps, eq, variables)
64-
badjlist[eidx] = sort!([vtois[var] for var in deps])
65+
badjlist[eidx] = sort!([vtois[convert(Variable,var)] for var in deps])
6566
empty!(deps)
6667
end
6768

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,39 +139,26 @@ end
139139

140140

141141
### Functions to determine which states a jump depends on
142-
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
143-
foreach(var -> (var in variables) && push!(dep, var), vars(jump.rate))
144-
dep
145-
end
142+
get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables) = get_variables!(dep, jump.rate, variables)
146143

147144
function get_variables!(dep, jump::MassActionJump, variables)
148-
jsr = jump.scaled_rates
149-
150-
if jsr isa Variable
151-
(jsr in variables) && push!(dep, jsr)
152-
elseif jsr isa Operation
153-
foreach(var -> (var in variables) && push!(dep, var), vars(jsr))
154-
end
155-
145+
get_variables!(dep, jump.scaled_rates, variables)
156146
for varasop in jump.reactant_stoch
157-
var = convert(Variable, varasop[1])
158-
(var in variables) && push!(dep, var)
147+
(varasop[1].op in variables) && push!(dep, varasop[1])
159148
end
160-
161149
dep
162150
end
163151

164152
### Functions to determine which states are modified by a given jump
165153
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
166154
for eq in jump.affect!
167-
st = convert(Variable, eq.lhs)
168-
(st in sts) && push!(mstates, st)
155+
st = eq.lhs
156+
(st.op in sts) && push!(mstates, st)
169157
end
170158
end
171159

172160
function modified_states!(mstates, jump::MassActionJump, sts)
173161
for (state,stoich) in jump.net_stoch
174-
st = convert(Variable, state)
175-
(st in sts) && push!(mstates, st)
162+
(state.op in sts) && push!(mstates, state)
176163
end
177164
end

src/utils.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,24 @@ get_variables(O::Operation)
8888
8989
Returns the variables in the Operation
9090
"""
91-
get_variables(e::Constant, vars = Operation[]) = vars
92-
function get_variables(e::Operation, vars = Operation[])
91+
get_variables!(vars, e::Constant, varlist=nothing) = vars
92+
get_variables(e::Constant, varlist=nothing) = get_variables!(Operation[], e, varlist)
93+
94+
function get_variables!(vars, e::Operation, varlist=nothing)
9395
if is_singleton(e)
94-
push!(vars, e)
96+
(isnothing(varlist) ? true : (e.op in varlist)) && push!(vars, e)
9597
else
96-
foreach(x -> get_variables(x, vars), e.args)
98+
foreach(x -> get_variables!(vars, x, varlist), e.args)
9799
end
98100
return unique(vars)
99101
end
102+
get_variables(e::Operation, varlist=nothing) = get_variables!(Operation[], e, varlist)
103+
104+
function get_variables!(vars, e::Equation, varlist=nothing)
105+
get_variables!(vars, e.rhs, varlist)
106+
end
107+
get_variables(e::Equation, varlist=nothing) = get_variables!(Operation[],e,varlist)
108+
100109

101110
# variable substitution
102111
"""
@@ -119,4 +128,4 @@ function _substitute(expr, dict::Dict)
119128
simplify(SymbolicUtils.substitute(expr, dict))
120129
end
121130

122-
@deprecate substitute_expr!(expr,s) substitute(expr,s)
131+
@deprecate substitute_expr!(expr,s) substitute(expr,s)

test/jumpsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,14 @@ jprob = JumpProblem(js3, dprob, Direct())
118118
m3 = getmean(jprob,Nsims)
119119
@test abs(m-m3)/m < .01
120120

121-
# maj jump test with dep graphs
121+
# maj jump test with various dep graphs
122122
js3b = JumpSystem([maj1,maj2], t, [S,I,R], [β,γ])
123123
jprobb = JumpProblem(js3b, dprob, NRM())
124124
m4 = getmean(jprobb,Nsims)
125125
@test abs(m-m4)/m < .01
126+
jprobc = JumpProblem(js3b, dprob, RSSA())
127+
m4 = getmean(jprobc,Nsims)
128+
@test abs(m-m4)/m < .01
126129

127130
# mass action jump tests for other reaction types (zero order, decay)
128131
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])

0 commit comments

Comments
 (0)