Skip to content

Commit 874ae51

Browse files
committed
adding defaults to ReactionSystem constructors
1 parent d3f46aa commit 874ae51

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

src/systems/reaction/reactionsystem.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -149,43 +149,51 @@ struct ReactionSystem <: AbstractSystem
149149
name::Symbol
150150
"""systems: The internal systems"""
151151
systems::Vector
152+
"""
153+
defaults: The default values to use when initial conditions and/or
154+
parameters are not supplied in `ODEProblem`.
155+
"""
156+
defaults::Dict
152157

153-
function ReactionSystem(eqs, iv, states, ps, observed, name, systems)
158+
function ReactionSystem(eqs, iv, states, ps, observed, name, systems, defaults)
154159
iv′ = value(iv)
155160
states′ = value.(states)
156161
ps′ = value.(ps)
157162
check_variables(states′, iv′)
158163
check_parameters(ps′, iv′)
159-
new(collect(eqs), iv′, states′, ps′, observed, name, systems)
164+
new(collect(eqs), iv′, states′, ps′, observed, name, systems, defaults)
160165
end
161166
end
162167

163168
function ReactionSystem(eqs, iv, species, params;
164169
observed = [],
165170
systems = [],
166-
name = gensym(:ReactionSystem))
171+
name = gensym(:ReactionSystem),
172+
default_u0=Dict(),
173+
default_p=Dict(),
174+
defaults=_merge(Dict(default_u0), Dict(default_p)))
167175

168176
#isempty(species) && error("ReactionSystems require at least one species.")
169-
ReactionSystem(eqs, iv, species, params, observed, name, systems)
177+
ReactionSystem(eqs, iv, species, params, observed, name, systems, defaults)
170178
end
171179

172180
function ReactionSystem(iv; kwargs...)
173181
ReactionSystem(Reaction[], iv, [], []; kwargs...)
174182
end
175183

176-
function equations(sys::ModelingToolkit.ReactionSystem)
177-
eqs = get_eqs(sys)
178-
systems = get_systems(sys)
179-
if isempty(systems)
180-
return eqs
181-
else
182-
eqs = [eqs;
183-
reduce(vcat,
184-
namespace_equations.(get_systems(sys));
185-
init=[])]
186-
return eqs
187-
end
188-
end
184+
# function equations(sys::ModelingToolkit.ReactionSystem)
185+
# eqs = get_eqs(sys)
186+
# systems = get_systems(sys)
187+
# if isempty(systems)
188+
# return eqs
189+
# else
190+
# eqs = [eqs;
191+
# reduce(vcat,
192+
# namespace_equations.(get_systems(sys));
193+
# init=[])]
194+
# return eqs
195+
# end
196+
# end
189197

190198
"""
191199
oderatelaw(rx; combinatoric_ratelaw=true)
@@ -419,7 +427,7 @@ function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem;
419427
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
420428
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, include_zero_odes=include_zero_odes)
421429
systems = map(sys -> (sys isa ODESystem) ? sys : convert(ODESystem, sys), get_systems(rs))
422-
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
430+
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
423431
end
424432

425433
"""
@@ -439,7 +447,7 @@ function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem;
439447
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
440448
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false, include_zero_odes=include_zero_odes)
441449
systems = convert.(NonlinearSystem, get_systems(rs))
442-
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
450+
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
443451
end
444452

445453
"""
@@ -487,6 +495,7 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
487495
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam(noise_scaling));
488496
name=name,
489497
systems=systems,
498+
defaults=get_defaults(rs),
490499
kwargs...)
491500
end
492501

@@ -507,7 +516,7 @@ function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem;
507516
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
508517
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
509518
systems = convert.(JumpSystem, get_systems(rs))
510-
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
519+
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
511520
end
512521

513522

test/reactionsystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@ show(io, rs)
3434
str = String(take!(io))
3535
@test count(isequal('\n'), str) < 30
3636

37+
# defaults test
38+
def_p = [ki => float(i) for (i, ki) in enumerate(k)]
39+
def_u0 = [A => 0.5, B => 1., C=> 1.5, D => 2.0]
40+
defs = merge(Dict(def_p), Dict(def_u0))
41+
42+
rs = ReactionSystem(rxs,t,[A,B,C,D],k; defaults=defs)
43+
odesys = convert(ODESystem,rs)
44+
sdesys = convert(SDESystem,rs)
45+
js = convert(JumpSystem,rs)
46+
nlsys = convert(NonlinearSystem,rs)
47+
48+
@test ModelingToolkit.get_defaults(rs) ==
49+
ModelingToolkit.get_defaults(odesys) ==
50+
ModelingToolkit.get_defaults(sdesys) ==
51+
ModelingToolkit.get_defaults(js) ==
52+
ModelingToolkit.get_defaults(nlsys) ==
53+
defs
54+
3755
# hard coded ODE rhs
3856
function oderhs(u,k,t)
3957
A = u[1]; B = u[2]; C = u[3]; D = u[4];

0 commit comments

Comments
 (0)