Skip to content

Commit d9d17e9

Browse files
committed
Instead of collecting constants up front and adding them to the fields of a system, collect the constants just prior to building the functions.
1 parent c0372aa commit d9d17e9

File tree

6 files changed

+68
-39
lines changed

6 files changed

+68
-39
lines changed

src/systems/abstractsystem.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ for prop in [:eqs
167167
:iv
168168
:states
169169
:ps
170-
:cs
171170
:var_to_name
172171
:ctrls
173172
:defaults
@@ -377,7 +376,6 @@ end
377376
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
378377
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
379378
namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
380-
namespace_constants(sys::AbstractSystem) = constants(sys, constants(sys))
381379

382380
function namespace_defaults(sys)
383381
defs = defaults(sys)
@@ -439,12 +437,6 @@ function parameters(sys::AbstractSystem)
439437
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
440438
end
441439

442-
function constants(sys::AbstractSystem)
443-
cs = get_cs(sys)
444-
systems = get_systems(sys)
445-
unique(isempty(systems) ? cs : [cs; reduce(vcat, namespace_constants.(systems))])
446-
end
447-
448440
function controls(sys::AbstractSystem)
449441
ctrls = get_ctrls(sys)
450442
systems = get_systems(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
128128
[eq.rhs for eq in eqs]
129129

130130
# Swap constants for their values
131-
cs = constants(sys)
131+
cs = collect_constants(eqs)
132132
if !isempty(cs) > 0
133133
cmap = map(x -> x => getdefault(x), cs)
134134
rhss = map(x -> substitute(x, cmap), rhss)
@@ -142,6 +142,7 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
142142
pre, sol_states = get_substitutions_and_solved_states(sys,
143143
no_postprocess = has_difference)
144144

145+
145146
if implicit_dae
146147
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states,
147148
kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ struct ODESystem <: AbstractODESystem
3636
states::Vector
3737
"""Parameter variables. Must not contain the independent variable."""
3838
ps::Vector
39-
"""Symbolic constants."""
40-
cs::Vector
4139
"""Array variables."""
4240
var_to_name::Any
4341
"""Control parameters (some subset of `ps`)."""
@@ -122,7 +120,7 @@ struct ODESystem <: AbstractODESystem
122120
"""
123121
metadata::Any
124122

125-
function ODESystem(deqs, iv, dvs, ps, cs, var_to_name, ctrls, observed, tgrad,
123+
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
126124
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
127125
torn_matching, connector_type, connections, preface, cevents,
128126
devents, tearing_state = nothing, substitutions = nothing,
@@ -135,16 +133,16 @@ struct ODESystem <: AbstractODESystem
135133
check_equations(equations(cevents), iv)
136134
end
137135
if checks == true || (checks & CheckUnits) > 0
138-
all_dimensionless([dvs; ps; iv; cs]) || check_units(deqs)
136+
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
139137
end
140-
new(deqs, iv, dvs, ps, cs, var_to_name, ctrls, observed, tgrad, jac,
138+
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
141139
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
142140
connector_type, connections, preface, cevents, devents, tearing_state,
143141
substitutions, metadata)
144142
end
145143
end
146144

147-
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps, cs;
145+
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
148146
controls = Num[],
149147
observed = Equation[],
150148
systems = ODESystem[],
@@ -166,7 +164,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps, cs;
166164
iv′ = value(scalarize(iv))
167165
dvs′ = value.(scalarize(dvs))
168166
ps′ = value.(scalarize(ps))
169-
cs′ = value.(scalarize(cs))
170167
ctrl′ = value.(scalarize(controls))
171168

172169
if !(isempty(default_u0) && isempty(default_p))
@@ -179,7 +176,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps, cs;
179176
var_to_name = Dict()
180177
process_variables!(var_to_name, defaults, dvs′)
181178
process_variables!(var_to_name, defaults, ps′)
182-
process_variables!(var_to_name, defaults, cs′)
183179
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
184180

185181
tgrad = RefValue(EMPTY_TGRAD)
@@ -193,7 +189,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps, cs;
193189
end
194190
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
195191
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
196-
ODESystem(deqs, iv′, dvs′, ps′, cs′, var_to_name, ctrl′, observed, tgrad, jac,
192+
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
197193
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
198194
connector_type, nothing, preface, cont_callbacks, disc_callbacks,
199195
metadata, checks = checks)
@@ -205,7 +201,6 @@ function ODESystem(eqs, iv = nothing; kwargs...)
205201
diffvars = OrderedSet()
206202
allstates = OrderedSet()
207203
ps = OrderedSet()
208-
cs = OrderedSet() #Constants
209204
# reorder equations such that it is in the form of `diffeq, algeeq`
210205
diffeq = Equation[]
211206
algeeq = Equation[]
@@ -223,8 +218,8 @@ function ODESystem(eqs, iv = nothing; kwargs...)
223218
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
224219
for eq in eqs
225220
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
226-
collect_vars!(allstates, ps, cs, eq.lhs, iv)
227-
collect_vars!(allstates, ps, cs, eq.rhs, iv)
221+
collect_vars!(allstates, ps, eq.lhs, iv)
222+
collect_vars!(allstates, ps, eq.rhs, iv)
228223
if isdiffeq(eq)
229224
diffvar, _ = var_from_nested_derivative(eq.lhs)
230225
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
@@ -240,7 +235,16 @@ function ODESystem(eqs, iv = nothing; kwargs...)
240235
algevars = setdiff(allstates, diffvars)
241236
# the orders here are very important!
242237
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
243-
collect(Iterators.flatten((diffvars, algevars))), ps, cs; kwargs...)
238+
collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
239+
end
240+
241+
function collect_constants(eqs) #Does this need to be different for other system types?
242+
constants = Set()
243+
for eq in eqs
244+
collect_constants!(constants,eq.lhs)
245+
collect_constants!(constants,eq.rhs)
246+
end
247+
return collect(constants)
244248
end
245249

246250
# NOTE: equality does not check cached Jacobian
@@ -265,7 +269,6 @@ function flatten(sys::ODESystem, noeqs = false)
265269
get_iv(sys),
266270
states(sys),
267271
parameters(sys),
268-
constants(sys),
269272
observed = observed(sys),
270273
continuous_events = continuous_events(sys),
271274
discrete_events = discrete_events(sys),
@@ -298,6 +301,13 @@ function build_explicit_observed_function(sys, ts;
298301
dep_vars = scalarize(setdiff(vars, ivs))
299302

300303
obs = observed(sys)
304+
305+
cs = collect_constants(obs)
306+
if !isempty(cs) > 0
307+
cmap = map(x -> x => getdefault(x), cs)
308+
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
309+
end
310+
301311
sts = Set(states(sys))
302312
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
303313
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)

src/utils.jl

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -466,46 +466,61 @@ function find_derivatives!(vars, expr, f)
466466
return vars
467467
end
468468

469-
function collect_vars!(states, parameters, constants, expr, iv)
469+
function collect_vars!(states, parameters, expr, iv)
470470
if expr isa Sym
471-
collect_var!(states, parameters, constants, expr, iv)
471+
collect_var!(states, parameters, expr, iv)
472472
else
473473
for var in vars(expr)
474474
if istree(var) && operation(var) isa Differential
475475
var, _ = var_from_nested_derivative(var)
476476
end
477-
collect_var!(states, parameters, constants, var, iv)
477+
collect_var!(states, parameters, var, iv)
478478
end
479479
end
480480
return nothing
481481
end
482482

483-
function collect_vars_difference!(states, parameters, constants, expr, iv)
483+
function collect_constants!(constants, expr)
484484
if expr isa Sym
485-
collect_var!(states, parameters, constants, expr, iv)
485+
collect_constant!(constants, expr)
486+
else
487+
for var in vars(expr)
488+
collect_constant!(constants, var)
489+
end
490+
end
491+
return nothing
492+
end
493+
494+
function collect_vars_difference!(states, parameters, expr, iv)
495+
if expr isa Sym
496+
collect_var!(states, parameters, expr, iv)
486497
else
487498
for var in vars(expr)
488499
if istree(var) && operation(var) isa Difference
489500
var, _ = var_from_nested_difference(var)
490501
end
491-
collect_var!(states, parameters, constants, var, iv)
502+
collect_var!(states, parameters, var, iv)
492503
end
493504
end
494505
return nothing
495506
end
496507

497-
function collect_var!(states, parameters, constants, var, iv)
508+
function collect_var!(states, parameters, var, iv)
498509
isequal(var, iv) && return nothing
499510
if isparameter(var) || (istree(var) && isparameter(operation(var)))
500511
push!(parameters, var)
501-
elseif isconstant(var)
502-
push!(constants,var)
503-
else
512+
elseif !isconstant(var)
504513
push!(states, var)
505514
end
506515
return nothing
507516
end
508517

518+
function collect_constant!(constants, var)
519+
if isconstant(var)
520+
push!(constants,var)
521+
end
522+
end
523+
509524
function get_postprocess_fbody(sys)
510525
if has_preface(sys) && (pre = preface(sys); pre !== nothing)
511526
pre_ = let pre = pre
@@ -549,6 +564,13 @@ function get_substitutions_and_solved_states(sys; no_postprocess = false)
549564
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)
550565
else
551566
@unpack subs = get_substitutions(sys)
567+
# Swap constants for their values
568+
cs = collect_constants(subs)
569+
if !isempty(cs) > 0
570+
cmap = map(x -> x => getdefault(x), cs)
571+
subs = map(x -> x.lhs ~ substitute(x.rhs, cmap), subs)
572+
end
573+
552574
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
553575
if no_postprocess
554576
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex,

test/constants.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@ MT = ModelingToolkit
99
D = Differential(t)
1010
eqs = [D(x) ~ a]
1111
@named sys = ODESystem(eqs)
12-
prob = ODEProblem(sys, [0, ], [0.0, 1.0],[])
13-
sol = solve(prob,Tsit5())
12+
prob = ODEProblem(sys, [0, ], [0.0, 1.0], [])
13+
sol = solve(prob, Tsit5())
1414

15-
# Test structural_simplify handling
16-
eqs = [D(x) ~ t,
15+
# Test structural_simplify substitutions & observed values
16+
eqs = [D(x) ~ 1,
1717
w ~ a]
1818
@named sys = ODESystem(eqs)
1919
simp = structural_simplify(sys);
20-
@test isequal(simp.substitutions.subs[1], w~a)
20+
@test isequal(simp.substitutions.subs[1], eqs[2])
21+
@test isequal(equations(simp)[1], eqs[1])
22+
prob = ODEProblem(simp, [0, ], [0.0, 1.0], [])
23+
sol = solve(prob, Tsit5())
24+
@test sol[w][1] == 1
2125

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ println("Last test requires gcc available in the path!")
4848
@safetestset "FuncAffect Test" begin include("funcaffect.jl") end
4949

5050
# Reference tests go Last
51-
#@safetestset "Latexify recipes Test" begin include("latexify.jl") end
51+
@safetestset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)