Skip to content

Commit 01f1412

Browse files
authored
Merge pull request #1821 from lamorton/constants2
Adding @Constants
2 parents 7a80a24 + 80332f7 commit 01f1412

26 files changed

+391
-97
lines changed

docs/src/basics/ContextualVariables.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ All modeling projects have some form of parameters. `@parameters` marks a variab
2020
as being the parameter of some system, which allows automatic detection algorithms
2121
to ignore such variables when attempting to find the states of a system.
2222

23+
## Constants
24+
25+
Constants are like parameters that:
26+
- always have a default value, which must be assigned when the constants are
27+
declared
28+
- do not show up in the list of parameters of a system.
29+
30+
The intended use-cases for constants are:
31+
- representing literals (eg, π) symbolically, which results in cleaner
32+
Latexification of equations (avoids turning `d ~ 2π*r` into `d = 6.283185307179586 r`)
33+
- allowing auto-generated unit conversion factors to live outside the list of
34+
parameters
35+
- representing fundamental constants (eg, speed of light `c`) that should never
36+
be adjusted inadvertently.
37+
2338
## Wildcard Variable Arguments
2439

2540
```julia
@@ -28,7 +43,7 @@ to ignore such variables when attempting to find the states of a system.
2843

2944
It is possible to define a dependent variable which is an open function as above,
3045
for which its arguments must be specified each time it is used. This is useful with
31-
PDEs for example, where one may need to use `u(t, x)` in the equations, but will
46+
PDEs for example, where one may need to use `u(t, x)` in the equations, but will
3247
need to be able to write `u(t, 0.0)` to define a boundary condition at `x = 0`.
3348

3449
## Variable metadata [Experimental/TODO]

docs/src/tutorials/ode_modeling.md

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@ But if you want to just see some code and run, here's an example:
1414
using ModelingToolkit
1515

1616
@variables t x(t) # independent and dependent variables
17-
@parameters τ # parameters
17+
@parameters τ # parameters
18+
@constants h = 1 # constants have an assigned value
1819
D = Differential(t) # define an operator for the differentiation w.r.t. time
1920

2021
# your first ODE, consisting of a single equation, the equality indicated by ~
21-
@named fol = ODESystem([ D(x) ~ (1 - x)/τ])
22+
@named fol = ODESystem([ D(x) ~ (h - x)/τ])
2223

2324
using DifferentialEquations: solve
2425
using Plots: plot
2526

2627
prob = ODEProblem(fol, [x => 0.0], (0.0,10.0), [τ => 3.0])
28+
# parameter `τ` can be assigned a value, but constant `h` cannot
2729
sol = solve(prob)
2830
plot(sol)
2931
```
@@ -42,19 +44,20 @@ first-order lag element:
4244
```
4345

4446
Here, ``t`` is the independent variable (time), ``x(t)`` is the (scalar) state
45-
variable, ``f(t)`` is an external forcing function, and ``\tau`` is a constant
47+
variable, ``f(t)`` is an external forcing function, and ``\tau`` is a
4648
parameter. In MTK, this system can be modelled as follows. For simplicity, we
47-
first set the forcing function to a constant value.
49+
first set the forcing function to a time-independent value.
4850

4951
```julia
5052
using ModelingToolkit
5153

5254
@variables t x(t) # independent and dependent variables
5355
@parameters τ # parameters
56+
@constants h = 1 # constants
5457
D = Differential(t) # define an operator for the differentiation w.r.t. time
5558

5659
# your first ODE, consisting of a single equation, indicated by ~
57-
@named fol_model = ODESystem(D(x) ~ (1 - x)/τ)
60+
@named fol_model = ODESystem(D(x) ~ (h - x)/τ)
5861
# Model fol_model with 1 equations
5962
# States (1):
6063
# x(t)
@@ -89,7 +92,7 @@ intermediate variable `RHS`:
8992

9093
```julia
9194
@variables RHS(t)
92-
@named fol_separate = ODESystem([ RHS ~ (1 - x)/τ,
95+
@named fol_separate = ODESystem([ RHS ~ (h - x)/τ,
9396
D(x) ~ RHS ])
9497
# Model fol_separate with 2 equations
9598
# States (2):
@@ -110,7 +113,7 @@ fol_simplified = structural_simplify(fol_separate)
110113

111114
equations(fol_simplified)
112115
# 1-element Array{Equation,1}:
113-
# Differential(t)(x(t)) ~ (τ^-1)*(1 - x(t))
116+
# Differential(t)(x(t)) ~ (τ^-1)*(h - x(t))
114117

115118
equations(fol_simplified) == equations(fol_model)
116119
# true
@@ -133,6 +136,12 @@ sol = solve(prob)
133136
plot(sol, vars=[x, RHS])
134137
```
135138

139+
By default, `structural_simplify` also replaces symbolic `constants` with
140+
their default values. This allows additional simplifications not possible
141+
if using `parameters` (eg, solution of linear equations by dividing out
142+
the constant's value, which cannot be done for parameters since they may
143+
be zero).
144+
136145
![Simulation result of first-order lag element, with right-hand side](https://user-images.githubusercontent.com/13935112/111958403-7e8d3e00-8aed-11eb-9d18-08b5180a59f9.png)
137146

138147
Note that similarly the indexing of the solution works via the names, and so

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ using .BipartiteGraphs
118118

119119
include("variables.jl")
120120
include("parameters.jl")
121+
include("constants.jl")
121122

122123
include("utils.jl")
123124
include("domains.jl")
@@ -214,7 +215,8 @@ export toexpr, get_variables
214215
export simplify, substitute
215216
export build_function
216217
export modelingtoolkitize
217-
export @variables, @parameters
218+
219+
export @variables, @parameters, @constants
218220
export @named, @nonamespace, @namespace, extend, compose, complete
219221
export debug_system
220222

src/constants.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import SymbolicUtils: symtype, term, hasmetadata, issym
2+
struct MTKConstantCtx end
3+
4+
isconstant(x::Num) = isconstant(unwrap(x))
5+
""" Test whether `x` is a constant-type Sym. """
6+
function isconstant(x)
7+
x = unwrap(x)
8+
x isa Symbolic && getmetadata(x, MTKConstantCtx, false)
9+
end
10+
11+
"""
12+
toconstant(s::Sym)
13+
14+
Maps the parameter to a constant. The parameter must have a default.
15+
"""
16+
function toconstant(s::Sym)
17+
hasmetadata(s, Symbolics.VariableDefaultValue) ||
18+
throw(ArgumentError("Constant `$(s)` must be assigned a default value."))
19+
setmetadata(s, MTKConstantCtx, true)
20+
end
21+
22+
toconstant(s::Num) = wrap(toconstant(value(s)))
23+
24+
"""
25+
$(SIGNATURES)
26+
27+
Define one or more constants.
28+
"""
29+
macro constants(xs...)
30+
Symbolics._parse_vars(:constants,
31+
Real,
32+
xs,
33+
toconstant) |> esc
34+
end

src/structural_transformation/codegen.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra
22

3-
using ModelingToolkit: isdifferenceeq, process_events
3+
using ModelingToolkit: isdifferenceeq, process_events, get_preprocess_constants
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

@@ -187,12 +187,14 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
187187

188188
fname = gensym("fun")
189189
# f is the function to find roots on
190+
funex = isscalar ? rhss[1] : MakeArray(rhss, SVector)
191+
pre = get_preprocess_constants(funex)
190192
f = Func([DestructuredArgs(vars, inbounds = !checkbounds)
191193
DestructuredArgs(params, inbounds = !checkbounds)],
192194
[],
193-
Let(needed_assignments[inner_idxs],
194-
isscalar ? rhss[1] : MakeArray(rhss, SVector),
195-
false)) |> SymbolicUtils.Code.toexpr
195+
pre(Let(needed_assignments[inner_idxs],
196+
funex,
197+
false))) |> SymbolicUtils.Code.toexpr
196198

197199
# solver call contains code to call the root-finding solver on the function f
198200
solver_call = LiteralExpr(quote
@@ -294,6 +296,8 @@ function build_torn_function(sys;
294296
syms = map(Symbol, states)
295297

296298
pre = get_postprocess_fbody(sys)
299+
cpre = get_preprocess_constants(rhss)
300+
pre2 = x -> pre(cpre(x))
297301

298302
expr = SymbolicUtils.Code.toexpr(Func([out
299303
DestructuredArgs(states,
@@ -302,10 +306,10 @@ function build_torn_function(sys;
302306
inbounds = !checkbounds)
303307
independent_variables(sys)],
304308
[],
305-
pre(Let([torn_expr;
306-
assignments[is_not_prepended_assignment]],
307-
funbody,
308-
false))),
309+
pre2(Let([torn_expr;
310+
assignments[is_not_prepended_assignment]],
311+
funbody,
312+
false))),
309313
sol_states)
310314
if expression
311315
expr, states
@@ -477,17 +481,19 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
477481
push!(subs, sym obs[eqidx].rhs)
478482
end
479483
pre = get_postprocess_fbody(sys)
480-
484+
cpre = get_preprocess_constants([obs[1:maxidx];
485+
isscalar ? ts[1] : MakeArray(ts, output_type)])
486+
pre2 = x -> pre(cpre(x))
481487
ex = Code.toexpr(Func([DestructuredArgs(solver_states, inbounds = !checkbounds)
482488
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
483489
independent_variables(sys)],
484490
[],
485-
pre(Let([collect(Iterators.flatten(solves))
486-
assignments[is_not_prepended_assignment]
487-
map(eq -> eq.lhs eq.rhs, obs[1:maxidx])
488-
subs],
489-
isscalar ? ts[1] : MakeArray(ts, output_type),
490-
false))), sol_states)
491+
pre2(Let([collect(Iterators.flatten(solves))
492+
assignments[is_not_prepended_assignment]
493+
map(eq -> eq.lhs eq.rhs, obs[1:maxidx])
494+
subs],
495+
isscalar ? ts[1] : MakeArray(ts, output_type),
496+
false))), sol_states)
491497

492498
expression ? ex : @RuntimeGeneratedFunction(ex)
493499
end

src/structural_transformation/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
179179
a, b, islinear = linear_expansion(term, var)
180180
a, b = unwrap(a), unwrap(b)
181181
islinear || (all_int_vars = false; continue)
182+
a = ModelingToolkit.fold_constants(a)
183+
b = ModelingToolkit.fold_constants(b)
182184
if a isa Symbolic
183185
all_int_vars = false
184186
if !allow_symbolic

src/systems/abstractsystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,8 @@ The optional argument `io` may take a tuple `(inputs, outputs)`.
10301030
This will convert all `inputs` to parameters and allow them to be unconnected, i.e.,
10311031
simplification will allow models where `n_states = n_equations - n_inputs`.
10321032
"""
1033-
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false, kwargs...)
1033+
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
1034+
simplify_constants = true, kwargs...)
10341035
sys = expand_connections(sys)
10351036
sys isa DiscreteSystem && return sys
10361037
state = TearingState(sys)
@@ -1046,6 +1047,18 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
10461047
return has_io ? (sys, input_idxs) : sys
10471048
end
10481049

1050+
function eliminate_constants(sys::AbstractSystem)
1051+
if has_eqs(sys)
1052+
eqs = get_eqs(sys)
1053+
eq_cs = collect_constants(eqs)
1054+
if !isempty(eq_cs)
1055+
new_eqs = eliminate_constants(eqs, eq_cs)
1056+
@set! sys.eqs = new_eqs
1057+
end
1058+
end
1059+
return sys
1060+
end
1061+
10491062
function io_preprocessing(sys::AbstractSystem, inputs,
10501063
outputs; simplify = false, kwargs...)
10511064
sys, input_idxs = structural_simplify(sys, (; inputs, outputs); simplify, kwargs...)

src/systems/callbacks.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,13 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
269269
p = map(x -> time_varying_as_func(value(x), sys), ps)
270270
t = get_iv(sys)
271271
condit = condition(cb)
272-
build_function(condit, u, t, p; expression, wrap_code = condition_header(), kwargs...)
272+
cs = collect_constants(condit)
273+
if !isempty(cs)
274+
cmap = map(x -> x => getdefault(x), cs)
275+
condit = substitute(condit, cmap)
276+
end
277+
build_function(condit, u, t, p; expression, wrap_code = condition_header(),
278+
kwargs...)
273279
end
274280

275281
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
@@ -337,9 +343,11 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
337343
t = get_iv(sys)
338344
integ = gensym(:MTKIntegrator)
339345
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
346+
pre = get_preprocess_constants(rhss)
340347
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
341348
wrap_code = add_integrator_header(integ, outvar),
342349
outputidxs = update_inds,
350+
postprocess_fbody = pre,
343351
kwargs...)
344352
# applied user-provided function to the generated expression
345353
if postprocess_affect_expr! !== nothing
@@ -376,7 +384,9 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states
376384
u = map(x -> time_varying_as_func(value(x), sys), dvs)
377385
p = map(x -> time_varying_as_func(value(x), sys), ps)
378386
t = get_iv(sys)
379-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...)
387+
pre = get_preprocess_constants(rhss)
388+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false},
389+
postprocess_fbody = pre, kwargs...)
380390

381391
affect_functions = map(cbs) do cb # Keep affect function separate
382392
eq_aff = affects(cb)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,15 @@ end
8383
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
8484
simplify = false, kwargs...)
8585
tgrad = calculate_tgrad(sys, simplify = simplify)
86-
return build_function(tgrad, dvs, ps, get_iv(sys); kwargs...)
86+
pre = get_preprocess_constants(tgrad)
87+
return build_function(tgrad, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
8788
end
8889

8990
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
9091
simplify = false, sparse = false, kwargs...)
9192
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
92-
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
93+
pre = get_preprocess_constants(jac)
94+
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
9395
end
9496

9597
function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys),
@@ -109,7 +111,9 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = states(sys),
109111
dvs = states(sys)
110112
@variables ˍ₋gamma
111113
jac = ˍ₋gamma * jac_du + jac_u
112-
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys); kwargs...)
114+
pre = get_preprocess_constants(jac)
115+
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys);
116+
postprocess_fbody = pre, kwargs...)
113117
end
114118

115119
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
@@ -163,8 +167,10 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
163167
end
164168

165169
pre = get_postprocess_fbody(sys)
170+
cpre = get_preprocess_constants(body)
171+
pre2 = x -> pre(cpre(x))
166172
f_oop, f_iip = build_function(body, u, p, t; expression = Val{false},
167-
postprocess_fbody = pre, kwargs...)
173+
postprocess_fbody = pre2, kwargs...)
168174

169175
cb_affect! = let f_oop = f_oop, f_iip = f_iip
170176
function cb_affect!(integ)

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ end
206206

207207
function ODESystem(eqs, iv = nothing; kwargs...)
208208
eqs = scalarize(eqs)
209-
# NOTE: this assumes that the order of algebric equations doesn't matter
209+
# NOTE: this assumes that the order of algebraic equations doesn't matter
210210
diffvars = OrderedSet()
211211
allstates = OrderedSet()
212212
ps = OrderedSet()
@@ -301,6 +301,13 @@ function build_explicit_observed_function(sys, ts;
301301
dep_vars = scalarize(setdiff(vars, ivs))
302302

303303
obs = observed(sys)
304+
305+
cs = collect_constants(obs)
306+
if !isempty(cs) > 0
307+
cmap = map(x -> x => getdefault(x), cs)
308+
obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs)
309+
end
310+
304311
sts = Set(states(sys))
305312
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
306313
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)

0 commit comments

Comments
 (0)