Skip to content

Commit cc5724b

Browse files
committed
Basic functionality is implemented.
1 parent d3b31f9 commit cc5724b

File tree

8 files changed

+63
-25
lines changed

8 files changed

+63
-25
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ using .BipartiteGraphs
115115

116116
include("variables.jl")
117117
include("parameters.jl")
118+
include("constants.jl")
118119

119120
include("utils.jl")
120121
include("domains.jl")
@@ -206,7 +207,7 @@ export toexpr, get_variables
206207
export simplify, substitute
207208
export build_function
208209
export modelingtoolkitize
209-
export @variables, @parameters
210+
export @variables, @parameters, @constants
210211
export @named, @nonamespace, @namespace, extend, compose
211212

212213
end # module

src/constants.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@ function isconstant(x)
55
x = unwrap(x)
66
x isa Symbolic && getmetadata(x, MTKConstantCtx, false)
77
end
8-
8+
isconstant(x::Num) = isconstant(unwrap(x))
99
"""
1010
toconst(s::Sym)
1111
1212
Maps the parameter to a constant. The parameter must have a default.
1313
"""
14-
function toconst(s)
14+
function toconstant(s)
1515
if s isa Symbolics.Arr
16-
Symbolics.wrap(toconst(Symbolics.unwrap(s)))
16+
Symbolics.wrap(toconstant(Symbolics.unwrap(s)))
1717
elseif s isa AbstractArray
18-
map(toconst, s)
18+
map(toconstant, s)
1919
else
20-
assert(hasmetadata(s,VariableDefaultValue))
21-
setmetadata(s, MTKConstCtx, true)
20+
hasmetadata(s, Symbolics.VariableDefaultValue) || throw(ArgumentError("Constant `$(s)` must be assigned a default value."))
21+
setmetadata(s, MTKConstantCtx, true)
2222
end
2323
end
24-
toconst(s::Num) = wrap(toconst(value(s)))
24+
toconstant(s::Num) = wrap(toconstant(value(s)))
2525

2626
"""
2727
$(SIGNATURES)
@@ -32,5 +32,5 @@ macro constants(xs...)
3232
Symbolics._parse_vars(:constants,
3333
Real,
3434
xs,
35-
toconst) |> esc
35+
toconstant) |> esc
3636
end

src/systems/abstractsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ for prop in [:eqs
167167
:iv
168168
:states
169169
:ps
170+
:cs
170171
:var_to_name
171172
:ctrls
172173
:defaults
@@ -376,6 +377,7 @@ end
376377
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
377378
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
378379
namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
380+
namespace_constants(sys::AbstractSystem) = constants(sys, constants(sys))
379381

380382
function namespace_defaults(sys)
381383
defs = defaults(sys)
@@ -437,6 +439,12 @@ function parameters(sys::AbstractSystem)
437439
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
438440
end
439441

442+
function constants(sys::AbstractSystem)
443+
cs = get_cs(sys)
444+
systems = get_systems(sys)
445+
unique(isempty(systems) ? cs : [cs; reduce(vcat, namespace_constants.(systems))])
446+
end
447+
440448
function controls(sys::AbstractSystem)
441449
ctrls = get_ctrls(sys)
442450
systems = get_systems(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
127127
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
128128
[eq.rhs for eq in eqs]
129129

130+
# Swap constants for their values
131+
cs = constants(sys)
132+
if !isempty(cs) > 0
133+
cmap = map(x -> x => getdefault(x), cs)
134+
rhss = map(x -> substitute(x, cmap), rhss)
135+
end
136+
130137
# TODO: add an optional check on the ordering of observed equations
131138
u = map(x -> time_varying_as_func(value(x), sys), dvs)
132139
p = map(x -> time_varying_as_func(value(x), sys), ps)

src/systems/diffeqs/odesystem.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct ODESystem <: AbstractODESystem
3636
states::Vector
3737
"""Parameter variables. Must not contain the independent variable."""
3838
ps::Vector
39+
"""Symbolic constants."""
40+
cs::Vector
3941
"""Array variables."""
4042
var_to_name::Any
4143
"""Control parameters (some subset of `ps`)."""
@@ -120,7 +122,7 @@ struct ODESystem <: AbstractODESystem
120122
"""
121123
metadata::Any
122124

123-
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
125+
function ODESystem(deqs, iv, dvs, ps, cs, var_to_name, ctrls, observed, tgrad,
124126
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
125127
torn_matching, connector_type, connections, preface, cevents,
126128
devents, tearing_state = nothing, substitutions = nothing,
@@ -133,16 +135,16 @@ struct ODESystem <: AbstractODESystem
133135
check_equations(equations(cevents), iv)
134136
end
135137
if checks == true || (checks & CheckUnits) > 0
136-
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
138+
all_dimensionless([dvs; ps; iv; cs]) || check_units(deqs)
137139
end
138-
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
140+
new(deqs, iv, dvs, ps, cs, var_to_name, ctrls, observed, tgrad, jac,
139141
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
140142
connector_type, connections, preface, cevents, devents, tearing_state,
141143
substitutions, metadata)
142144
end
143145
end
144146

145-
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
147+
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps, cs;
146148
controls = Num[],
147149
observed = Equation[],
148150
systems = ODESystem[],
@@ -164,6 +166,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
164166
iv′ = value(scalarize(iv))
165167
dvs′ = value.(scalarize(dvs))
166168
ps′ = value.(scalarize(ps))
169+
cs′ = value.(scalarize(cs))
167170
ctrl′ = value.(scalarize(controls))
168171

169172
if !(isempty(default_u0) && isempty(default_p))
@@ -176,6 +179,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
176179
var_to_name = Dict()
177180
process_variables!(var_to_name, defaults, dvs′)
178181
process_variables!(var_to_name, defaults, ps′)
182+
process_variables!(var_to_name, defaults, cs′)
179183
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
180184

181185
tgrad = RefValue(EMPTY_TGRAD)
@@ -189,7 +193,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
189193
end
190194
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
191195
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
192-
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
196+
ODESystem(deqs, iv′, dvs′, ps′, cs′, var_to_name, ctrl′, observed, tgrad, jac,
193197
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
194198
connector_type, nothing, preface, cont_callbacks, disc_callbacks,
195199
metadata, checks = checks)
@@ -201,6 +205,7 @@ function ODESystem(eqs, iv = nothing; kwargs...)
201205
diffvars = OrderedSet()
202206
allstates = OrderedSet()
203207
ps = OrderedSet()
208+
cs = OrderedSet() #Constants
204209
# reorder equations such that it is in the form of `diffeq, algeeq`
205210
diffeq = Equation[]
206211
algeeq = Equation[]
@@ -218,8 +223,8 @@ function ODESystem(eqs, iv = nothing; kwargs...)
218223
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
219224
for eq in eqs
220225
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
221-
collect_vars!(allstates, ps, eq.lhs, iv)
222-
collect_vars!(allstates, ps, eq.rhs, iv)
226+
collect_vars!(allstates, ps, cs, eq.lhs, iv)
227+
collect_vars!(allstates, ps, cs, eq.rhs, iv)
223228
if isdiffeq(eq)
224229
diffvar, _ = var_from_nested_derivative(eq.lhs)
225230
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
@@ -235,7 +240,7 @@ function ODESystem(eqs, iv = nothing; kwargs...)
235240
algevars = setdiff(allstates, diffvars)
236241
# the orders here are very important!
237242
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
238-
collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
243+
collect(Iterators.flatten((diffvars, algevars))), ps, cs; kwargs...)
239244
end
240245

241246
# NOTE: equality does not check cached Jacobian
@@ -260,6 +265,7 @@ function flatten(sys::ODESystem, noeqs = false)
260265
get_iv(sys),
261266
states(sys),
262267
parameters(sys),
268+
constants(sys),
263269
observed = observed(sys),
264270
continuous_events = continuous_events(sys),
265271
discrete_events = discrete_events(sys),

src/utils.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,38 +466,40 @@ function find_derivatives!(vars, expr, f)
466466
return vars
467467
end
468468

469-
function collect_vars!(states, parameters, expr, iv)
469+
function collect_vars!(states, parameters, constants, expr, iv)
470470
if expr isa Sym
471-
collect_var!(states, parameters, expr, iv)
471+
collect_var!(states, parameters, constants, expr, iv)
472472
else
473473
for var in vars(expr)
474474
if istree(var) && operation(var) isa Differential
475475
var, _ = var_from_nested_derivative(var)
476476
end
477-
collect_var!(states, parameters, var, iv)
477+
collect_var!(states, parameters, constants, var, iv)
478478
end
479479
end
480480
return nothing
481481
end
482482

483-
function collect_vars_difference!(states, parameters, expr, iv)
483+
function collect_vars_difference!(states, parameters, constants, expr, iv)
484484
if expr isa Sym
485-
collect_var!(states, parameters, expr, iv)
485+
collect_var!(states, parameters, constants, expr, iv)
486486
else
487487
for var in vars(expr)
488488
if istree(var) && operation(var) isa Difference
489489
var, _ = var_from_nested_difference(var)
490490
end
491-
collect_var!(states, parameters, var, iv)
491+
collect_var!(states, parameters, constants, var, iv)
492492
end
493493
end
494494
return nothing
495495
end
496496

497-
function collect_var!(states, parameters, var, iv)
497+
function collect_var!(states, parameters, constants, var, iv)
498498
isequal(var, iv) && return nothing
499499
if isparameter(var) || (istree(var) && isparameter(operation(var)))
500500
push!(parameters, var)
501+
elseif isconstant(var)
502+
push!(constants,var)
501503
else
502504
push!(states, var)
503505
end

test/constants.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
using Test
3+
MT = ModelingToolkit
4+
5+
@constants a = 1
6+
@test_throws MT.MissingDefaultError @constants b
7+
8+
@variables t x(t) w(t)
9+
D = Differential(t)
10+
eqs = [D(x) ~ a]
11+
@named sys = ODESystem(eqs)
12+
prob = ODEProblem(sys, [0, ], [0.0, 1.0],[])
13+
sol = solve(prob,Tsit5())
14+

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ println("Last test requires gcc available in the path!")
4848
@safetestset "FuncAffect Test" begin include("funcaffect.jl") end
4949

5050
# Reference tests go Last
51-
@safetestset "Latexify recipes Test" begin include("latexify.jl") end
51+
#@safetestset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)