Skip to content

Commit 9771134

Browse files
Merge pull request #1129 from anandijain/aj/rxnsys_defaults
adding defaults to ReactionSystem constructors
2 parents d3f46aa + 07acf83 commit 9771134

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

src/systems/reaction/reactionsystem.jl

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -149,44 +149,38 @@ struct ReactionSystem <: AbstractSystem
149149
name::Symbol
150150
"""systems: The internal systems"""
151151
systems::Vector
152+
"""
153+
defaults: The default values to use when initial conditions and/or
154+
parameters are not supplied in `ODEProblem`.
155+
"""
156+
defaults::Dict
152157

153-
function ReactionSystem(eqs, iv, states, ps, observed, name, systems)
158+
function ReactionSystem(eqs, iv, states, ps, observed, name, systems, defaults)
154159
iv′ = value(iv)
155160
states′ = value.(states)
156161
ps′ = value.(ps)
157162
check_variables(states′, iv′)
158163
check_parameters(ps′, iv′)
159-
new(collect(eqs), iv′, states′, ps′, observed, name, systems)
164+
new(collect(eqs), iv′, states′, ps′, observed, name, systems, defaults)
160165
end
161166
end
162167

163168
function ReactionSystem(eqs, iv, species, params;
164169
observed = [],
165170
systems = [],
166-
name = gensym(:ReactionSystem))
171+
name = gensym(:ReactionSystem),
172+
default_u0=Dict(),
173+
default_p=Dict(),
174+
defaults=_merge(Dict(default_u0), Dict(default_p)))
167175

168176
#isempty(species) && error("ReactionSystems require at least one species.")
169-
ReactionSystem(eqs, iv, species, params, observed, name, systems)
177+
ReactionSystem(eqs, iv, species, params, observed, name, systems, defaults)
170178
end
171179

172180
function ReactionSystem(iv; kwargs...)
173181
ReactionSystem(Reaction[], iv, [], []; kwargs...)
174182
end
175183

176-
function equations(sys::ModelingToolkit.ReactionSystem)
177-
eqs = get_eqs(sys)
178-
systems = get_systems(sys)
179-
if isempty(systems)
180-
return eqs
181-
else
182-
eqs = [eqs;
183-
reduce(vcat,
184-
namespace_equations.(get_systems(sys));
185-
init=[])]
186-
return eqs
187-
end
188-
end
189-
190184
"""
191185
oderatelaw(rx; combinatoric_ratelaw=true)
192186
@@ -419,7 +413,7 @@ function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem;
419413
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
420414
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, include_zero_odes=include_zero_odes)
421415
systems = map(sys -> (sys isa ODESystem) ? sys : convert(ODESystem, sys), get_systems(rs))
422-
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
416+
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
423417
end
424418

425419
"""
@@ -439,7 +433,7 @@ function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem;
439433
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
440434
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false, include_zero_odes=include_zero_odes)
441435
systems = convert.(NonlinearSystem, get_systems(rs))
442-
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
436+
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
443437
end
444438

445439
"""
@@ -487,6 +481,7 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
487481
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam(noise_scaling));
488482
name=name,
489483
systems=systems,
484+
defaults=get_defaults(rs),
490485
kwargs...)
491486
end
492487

@@ -507,7 +502,7 @@ function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem;
507502
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
508503
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
509504
systems = convert.(JumpSystem, get_systems(rs))
510-
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
505+
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
511506
end
512507

513508

test/reactionsystem.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,35 @@ show(io, rs)
3434
str = String(take!(io))
3535
@test count(isequal('\n'), str) < 30
3636

37+
# defaults test
38+
def_p = [ki => float(i) for (i, ki) in enumerate(k)]
39+
def_u0 = [A => 0.5, B => 1., C=> 1.5, D => 2.0]
40+
defs = merge(Dict(def_p), Dict(def_u0))
41+
42+
rs = ReactionSystem(rxs,t,[A,B,C,D],k; defaults=defs)
43+
odesys = convert(ODESystem,rs)
44+
sdesys = convert(SDESystem,rs)
45+
js = convert(JumpSystem,rs)
46+
nlsys = convert(NonlinearSystem,rs)
47+
48+
@test ModelingToolkit.get_defaults(rs) ==
49+
ModelingToolkit.get_defaults(odesys) ==
50+
ModelingToolkit.get_defaults(sdesys) ==
51+
ModelingToolkit.get_defaults(js) ==
52+
ModelingToolkit.get_defaults(nlsys) ==
53+
defs
54+
55+
u0map = [A=>5.] # was 0.5
56+
pmap = [k[1]=>5.] # was 1.
57+
prob = ODEProblem(rs, u0map, (0,10.), pmap)
58+
@test prob.p[1] == 5.
59+
@test prob.u0[1] == 5.
60+
u0 = [10., 11., 12., 13.]
61+
ps = [float(x) for x in 100:119]
62+
prob = ODEProblem(rs, u0, (0,10.), ps)
63+
@test prob.p == ps
64+
@test prob.u0 == u0
65+
3766
# hard coded ODE rhs
3867
function oderhs(u,k,t)
3968
A = u[1]; B = u[2]; C = u[3]; D = u[4];

0 commit comments

Comments
 (0)