Skip to content

Commit 28370ac

Browse files
shashiYingboMa
andcommitted
Fix ReactionSystems tests
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 01c0646 commit 28370ac

File tree

5 files changed

+56
-28
lines changed

5 files changed

+56
-28
lines changed

src/build_function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ Special Keyword Argumnets:
209209
- `DaggerForm()`: Multithreading and multiprocessing using Julia's Dagger.jl
210210
for dynamic scheduling and load balancing.
211211
- `conv`: The conversion function of the Operation to Expr. By default this uses
212-
the `toexpr` function utilized in `convert(Expr,x)`.
212+
the `toexpr` function.
213213
- `checkbounds`: For whether to enable bounds checking inside of the generated
214214
function. Defaults to false, meaning that `@inbounds` is applied.
215215
- `linenumbers`: Determines whether the generated function expression retains

src/systems/jumps/jumpsystem.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,22 @@ function JumpSystem(eqs, iv, states, ps;
6767
JumpSystem{typeof(ap)}(ap, value(iv), value.(states), value.(ps), pins, observed, name, systems)
6868
end
6969

70-
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
71-
independent_variable(js),
72-
expression=Val{true})
70+
function generate_rate_function(js, rate)
71+
build_function(rate, states(js), parameters(js),
72+
independent_variable(js),
73+
conv = states_to_sym(states(js)),
74+
expression=Val{true})
75+
end
7376

74-
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
75-
parameters(js),
76-
independent_variable(js),
77-
expression=Val{true},
78-
headerfun=add_integrator_header,
79-
outputidxs=outputidxs)[2]
77+
function generate_affect_function(js, affect, outputidxs)
78+
build_function(affect, states(js),
79+
parameters(js),
80+
conv = states_to_sym(states(js)),
81+
independent_variable(js),
82+
expression=Val{true},
83+
headerfun=add_integrator_header,
84+
outputidxs=outputidxs)[2]
85+
end
8086

8187
function assemble_vrj(js, vrj, statetoid)
8288
rate = eval(generate_rate_function(js, vrj.rate))
@@ -121,7 +127,7 @@ end
121127
function numericrstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
122128
rs = Vector{Pair{Int,W}}()
123129
for (spec,stoich) in mtrs
124-
if !(spec isa Term) && iszero(spec)
130+
if !(spec isa Term) && _iszero(spec)
125131
push!(rs, 0 => stoich)
126132
else
127133
push!(rs, statetoid[value(spec)] => stoich)
@@ -134,8 +140,8 @@ end
134140
function numericnstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
135141
ns = Vector{Pair{Int,W}}()
136142
for (spec,stoich) in mtrs
137-
!(spec isa Term) && iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
138-
push!(ns, statetoid[value(spec)] => stoich)
143+
!(spec isa Term) && _iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
144+
push!(ns, statetoid[spec] => stoich)
139145
end
140146
sort!(ns)
141147
end
@@ -278,13 +284,13 @@ end
278284

279285
### Functions to determine which states a jump depends on
280286
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
281-
(jump.rate isa Operation) && get_variables!(dep, jump.rate, variables)
287+
(jump.rate isa Symbolic) && get_variables!(dep, jump.rate, variables)
282288
dep
283289
end
284290

285291
function get_variables!(dep, jump::MassActionJump, variables)
286292
sr = jump.scaled_rates
287-
(sr isa Term) && get_variables!(dep, sr, variables)
293+
(sr isa Symbolic) && get_variables!(dep, sr, variables)
288294
for varasop in jump.reactant_stoch
289295
any(isequal(varasop[1]), variables) && push!(dep, varasop[1])
290296
end

src/systems/reaction/reactionsystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ function assemble_jumps(rs; combinatoric_ratelaws=true)
349349
rl = jumpratelaw(rx, rxvars=rxvars, combinatoric_ratelaw=combinatoric_ratelaws)
350350
affect = Vector{Equation}()
351351
for (spec,stoich) in rx.netstoich
352-
push!(affect, var2op(spec) ~ var2op(spec) + stoich)
352+
push!(affect, spec ~ spec + stoich)
353353
end
354354
if haveivdep
355355
push!(veqs, VariableRateJump(rl,affect))
@@ -522,16 +522,16 @@ end
522522

523523
# determine which species a reaction depends on
524524
function get_variables!(deps::Set, rx::Reaction, variables)
525-
(rx.rate isa Term) && get_variables!(deps, rx.rate, variables)
525+
(rx.rate isa Symbolic) && get_variables!(deps, rx.rate, variables)
526526
for s in rx.substrates
527527
push!(deps, s)
528528
end
529529
@show deps
530530
end
531531

532532
# determine which species a reaction modifies
533-
function modified_states!(mstates, rx::Reaction, sts)
533+
function modified_states!(mstates, rx::Reaction, sts::Set)
534534
for (species,stoich) in rx.netstoich
535-
(species in sts) && push!(mstates, species())
535+
(species in sts) && push!(mstates, species)
536536
end
537537
end

src/utils.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,16 @@ Returns the variables in the expression
6565
"""
6666
get_variables(e::Num, varlist=nothing) = get_variables(value(e), varlist)
6767
get_variables!(vars, e, varlist=nothing) = vars
68-
get_variables!(vars, e::Sym, varlist=nothing) = push!(vars, e)
6968

7069
is_singleton(e::Term) = e.op isa Sym
7170
is_singleton(e::Sym) = true
7271
is_singleton(e) = false
7372

74-
function get_variables!(vars, e::Term, varlist=nothing)
73+
get_variables!(vars, e::Number, varlist=nothing) = vars
74+
75+
function get_variables!(vars, e::Symbolic, varlist=nothing)
7576
if is_singleton(e)
76-
if isnothing(varlist) || e in varlist
77+
if isnothing(varlist) || any(isequal(e), varlist)
7778
push!(vars, e)
7879
end
7980
else
@@ -86,9 +87,10 @@ function get_variables!(vars, e::Equation, varlist=nothing)
8687
get_variables!(vars, e.rhs, varlist)
8788
end
8889

90+
get_variables(e, varlist=nothing) = get_variables!([], e, varlist)
91+
8992
modified_states!(mstates, e::Equation, statelist=nothing) = get_variables!(mstates, e.lhs, statelist)
9093

91-
get_variables(e, varlist=nothing) = get_variables!([], e, varlist)
9294

9395
# variable substitution
9496
# Piracy but mild
@@ -122,3 +124,22 @@ macro showarr(x)
122124
end
123125

124126
@deprecate substitute_expr!(expr,s) substitute(expr,s)
127+
128+
function states_to_sym(states)
129+
function _states_to_sym(O)
130+
if O isa Equation
131+
Expr(:(=), _states_to_sym(O.lhs), _states_to_sym(O.rhs))
132+
elseif O isa Term
133+
if isa(O.op, Sym)
134+
any(isequal(O), states) && return O.op.name # dependent variables
135+
return build_expr(:call, Any[O.op.name; _states_to_sym.(O.args)])
136+
else
137+
return build_expr(:call, Any[Symbol(O.op); _states_to_sym.(O.args)])
138+
end
139+
elseif O isa Num
140+
return _states_to_sym(value(O))
141+
else
142+
return toexpr(O)
143+
end
144+
end
145+
end

test/reactionsystem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ for i in cidxs
161161
crj = MT.assemble_crj(js, js.eqs[i], statetoid)
162162
@test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time))
163163
fake_integrator1 = (u=zeros(4),p=p,t=0); fake_integrator2 = deepcopy(fake_integrator1);
164-
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
164+
crj.affect!(fake_integrator1);
165+
jumps[i].affect!(fake_integrator2);
165166
@test fake_integrator1 == fake_integrator2
166167
end
167168
for i in vidxs
@@ -190,12 +191,12 @@ rxs = [Reaction(k1*S, [S,I], [I], [2,3], [2]),
190191
rs = ReactionSystem(rxs, t, [S,I,R], [k1,k2])
191192
@test isequal(ModelingToolkit.oderatelaw(rs.eqs[1]), k1*S*S^2*I^3/(factorial(2)*factorial(3)))
192193
@test_skip isequal(ModelingToolkit.jumpratelaw(rs.eqs[1]), k1*S*binomial(S,2)*binomial(I,3))
193-
dep = Set{Operation}()
194-
ModelingToolkit.get_variables!(dep, rxs[2], states(rs))
194+
dep = Set()
195+
ModelingToolkit.get_variables!(dep, rxs[2], Set(states(rs)))
195196
dep2 = Set([R,I])
196197
@test dep == dep2
197-
dep = Set{Operation}()
198-
ModelingToolkit.modified_states!(dep, rxs[2], states(rs))
198+
dep = Set()
199+
ModelingToolkit.modified_states!(dep, rxs[2], Set(states(rs)))
199200
@test dep == Set([R,I])
200201

201202
isequal2(a,b) = isequal(simplify(a), simplify(b))

0 commit comments

Comments
 (0)