Skip to content

Commit c651efe

Browse files
committed
Merge remote-tracking branch 'origin' into ss_discrete
merge master
2 parents 66266e4 + 0a5c1ce commit c651efe

20 files changed

+503
-60
lines changed

docs/src/basics/Debugging.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ dsol = solve(dprob, Tsit5());
3535
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `` function.
3636
We could have figured that out ourselves, but it is not always so obvious for more complex models.
3737

38+
Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality.
39+
40+
```@example debug
41+
@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"])
42+
```
43+
44+
The assertions must be an iterable of pairs, where the first element is the symbolic condition and
45+
the second is a message to be logged when the condition fails. All assertions are added to the
46+
generated code and will cause the solver to reject steps that fail the assertions. For systems such
47+
as the above where the assertion is guaranteed to eventually fail, the solver will likely exit
48+
with a `dtmin` failure..
49+
50+
```@example debug
51+
prob = ODEProblem(sys, [], (0.0, 10.0))
52+
sol = solve(prob, Tsit5())
53+
```
54+
55+
We can use `debug_system` to log the failing assertions in each call to the RHS function.
56+
57+
```@repl debug
58+
dsys = debug_system(sys; functions = []);
59+
dprob = ODEProblem(dsys, [], (0.0, 10.0));
60+
dsol = solve(dprob, Tsit5());
61+
```
62+
63+
Note the logs containing the failed assertion and corresponding message. To temporarily disable
64+
logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`.
65+
66+
```@repl debug
67+
dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false;
68+
solve(drob, Tsit5());
69+
```
70+
3871
```@docs
3972
debug_system
4073
```

docs/src/basics/Variable_metadata.md

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
5454
5555
@variables i(t) [connect = Flow]
5656
@variables k(t) [connect = Stream]
57+
hasconnect(i)
58+
```
59+
60+
```@example connect
61+
getconnect(k)
5762
```
5863

5964
## Input or output
@@ -177,8 +182,45 @@ A variable can be marked `irreducible` to prevent it from being moved to an
177182
`observed` state. This forces the variable to be computed during solving so that
178183
it can be accessed in [callbacks](@ref events)
179184

180-
```julia
181-
@variable important_value [irreducible = true]
185+
```@example metadata
186+
@variables important_value [irreducible = true]
187+
isirreducible(important_value)
188+
```
189+
190+
## State Priority
191+
192+
When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata.
193+
194+
```@example metadata
195+
@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
196+
state_priority(important_dof)
197+
```
198+
199+
## Units
200+
201+
Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions.
202+
203+
```@example metadata
204+
@variables speed [unit = u"m/s"]
205+
hasunit(speed)
206+
```
207+
208+
```@example metadata
209+
getunit(speed)
210+
```
211+
212+
## Miscellaneous metadata
213+
214+
User-defined metadata can be added using the `misc` metadata. This can be queried
215+
using the `hasmisc` and `getmisc` functions.
216+
217+
```@example metadata
218+
@variables u [misc = :conserved_parameter] y [misc = [2, 4, 6]]
219+
hasmisc(u)
220+
```
221+
222+
```@example metadata
223+
getmisc(y)
182224
```
183225

184226
## Additional functions

docs/src/tutorials/initialization.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ principles of initialization of DAE systems. Take a DAE written in semi-explicit
1414
form:
1515

1616
```math
17-
x' = f(x,y,t)\\
18-
0 = g(x,y,t)
17+
\begin{aligned}
18+
x^\prime &= f(x,y,t) \\
19+
0 &= g(x,y,t)
20+
\end{aligned}
1921
```
2022

2123
where ``x`` are the differential variables and ``y`` are the algebraic variables.

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ export initial_state, transition, activeState, entry, ticksInState, timeInState
247247
export @component, @mtkmodel, @mtkbuild
248248
export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbance,
249249
istunable, getdist, hasdist,
250-
tunable_parameters, isirreducible, getdescription, hasdescription
250+
tunable_parameters, isirreducible, getdescription, hasdescription,
251+
hasunit, getunit, hasconnect, getconnect,
252+
hasmisc, getmisc, state_priority
251253
export ode_order_lowering, dae_order_lowering, liouville_transform
252254
export PDESystem
253255
export Differential, expand_derivatives, @derivatives

src/debugging.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,59 @@ function debug_sub(ex, funcs; kw...)
4242
f in funcs ? logged_fun(f, args...; kw...) :
4343
maketerm(typeof(ex), f, args, metadata(ex))
4444
end
45+
46+
"""
47+
$(TYPEDSIGNATURES)
48+
49+
A function which returns `NaN` if `condition` fails, and `0.0` otherwise.
50+
"""
51+
function _nan_condition(condition::Bool)
52+
condition ? 0.0 : NaN
53+
end
54+
55+
@register_symbolic _nan_condition(condition::Bool)
56+
57+
"""
58+
$(TYPEDSIGNATURES)
59+
60+
A function which takes a condition `expr` and returns `NaN` if it is false,
61+
and zero if it is true. In case the condition is false and `log == true`,
62+
`message` will be logged as an `@error`.
63+
"""
64+
function _debug_assertion(expr::Bool, message::String, log::Bool)
65+
value = _nan_condition(expr)
66+
isnan(value) || return value
67+
log && @error message
68+
return value
69+
end
70+
71+
@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool)
72+
73+
"""
74+
Boolean parameter added to models returned from `debug_system` to control logging of
75+
assertions.
76+
"""
77+
const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false)
78+
79+
"""
80+
$(TYPEDSIGNATURES)
81+
82+
Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN`
83+
if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a
84+
parameter in the system, it will control whether the message associated with each
85+
assertion is logged when it fails.
86+
"""
87+
function get_assertions_expr(sys::AbstractSystem)
88+
asserts = assertions(sys)
89+
term = 0
90+
if is_parameter(sys, ASSERTION_LOG_VARIABLE)
91+
for (k, v) in asserts
92+
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
93+
end
94+
else
95+
for (k, v) in asserts
96+
term += _nan_condition(k)
97+
end
98+
end
99+
return term
100+
end

src/systems/abstractsystem.jl

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ for prop in [:eqs
983983
:gui_metadata
984984
:discrete_subsystems
985985
:parameter_dependencies
986+
:assertions
986987
:solved_unknowns
987988
:split_idxs
988989
:parent
@@ -1468,6 +1469,24 @@ end
14681469
"""
14691470
$(TYPEDSIGNATURES)
14701471
1472+
Get the assertions for a system `sys` and its subsystems.
1473+
"""
1474+
function assertions(sys::AbstractSystem)
1475+
has_assertions(sys) || return Dict{BasicSymbolic, String}()
1476+
1477+
asserts = get_assertions(sys)
1478+
systems = get_systems(sys)
1479+
namespaced_asserts = mapreduce(
1480+
merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys
1481+
Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v
1482+
for (k, v) in assertions(subsys))
1483+
end
1484+
return merge(asserts, namespaced_asserts)
1485+
end
1486+
1487+
"""
1488+
$(TYPEDSIGNATURES)
1489+
14711490
Get the guesses for variables in the initialization system of the system `sys` and its subsystems.
14721491
14731492
See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref).
@@ -2248,15 +2267,20 @@ macro mtkbuild(exprs...)
22482267
expr = exprs[1]
22492268
named_expr = ModelingToolkit.named_expr(expr)
22502269
name = named_expr.args[1]
2251-
kwargs = if length(exprs) > 1
2252-
NamedTuple{Tuple(ex.args[1] for ex in Base.tail(exprs))}(Tuple(ex.args[2]
2253-
for ex in Base.tail(exprs)))
2270+
kwargs = Base.tail(exprs)
2271+
kwargs = map(kwargs) do ex
2272+
@assert ex.head == :(=)
2273+
Expr(:kw, ex.args[1], ex.args[2])
2274+
end
2275+
if isempty(kwargs)
2276+
kwargs = ()
22542277
else
2255-
(;)
2278+
kwargs = (Expr(:parameters, kwargs...),)
22562279
end
2280+
call_expr = Expr(:call, structural_simplify, kwargs..., name)
22572281
esc(quote
22582282
$named_expr
2259-
$name = $structural_simplify($name; $(kwargs)...)
2283+
$name = $call_expr
22602284
end)
22612285
end
22622286

@@ -2278,6 +2302,13 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
22782302
1 => 1
22792303
sin(P(t)) => 0.0
22802304
```
2305+
2306+
Additionally, all assertions in the system are optionally logged when they fail.
2307+
A new parameter is also added to the system which controls whether the message associated
2308+
with each assertion will be logged when the assertion fails. This parameter defaults to
2309+
`true` and can be toggled by symbolic indexing with
2310+
`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
2311+
`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging.
22812312
"""
22822313
function debug_system(
22832314
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
@@ -2288,11 +2319,17 @@ function debug_system(
22882319
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
22892320
end
22902321
if has_eqs(sys)
2291-
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2322+
eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2323+
@set! sys.eqs = eqs
2324+
@set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE])
2325+
@set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true))
22922326
end
22932327
if has_observed(sys)
22942328
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
22952329
end
2330+
if iscomplete(sys)
2331+
sys = complete(sys; split = is_split(sys))
2332+
end
22962333
return sys
22972334
end
22982335

@@ -3031,6 +3068,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30313068
kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses))
30323069
end
30333070

3071+
if has_assertions(basesys)
3072+
kwargs = merge(
3073+
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
3074+
end
3075+
30343076
return T(args...; kwargs...)
30353077
end
30363078

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
168168
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
169169
[eq.rhs for eq in eqs]
170170

171+
if !isempty(assertions(sys))
172+
rhss[end] += unwrap(get_assertions_expr(sys))
173+
end
174+
171175
# TODO: add an optional check on the ordering of observed equations
172176
u = dvs
173177
p = reorder_parameters(sys, ps)
@@ -542,7 +546,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
542546
The arguments `dvs` and `ps` are used to set the order of the dependent
543547
variable and parameter vectors, respectively.
544548
"""
545-
struct ODEFunctionExpr{iip} end
549+
struct ODEFunctionExpr{iip, specialize} end
546550

547551
struct ODEFunctionClosure{O, I} <: Function
548552
f_oop::O
@@ -551,7 +555,7 @@ end
551555
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
552556
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
553557

554-
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
558+
function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
555559
ps = parameters(sys), u0 = nothing;
556560
version = nothing, tgrad = false,
557561
jac = false, p = nothing,
@@ -560,14 +564,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
560564
steady_state = false,
561565
sparsity = false,
562566
observedfun_exp = nothing,
563-
kwargs...) where {iip}
567+
kwargs...) where {iip, specialize}
564568
if !iscomplete(sys)
565569
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`")
566570
end
567571
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
568572

569-
dict = Dict()
570-
571573
fsym = gensym(:f)
572574
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
573575
tgradsym = gensym(:tgrad)
@@ -590,30 +592,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
590592
_jac = :($jacsym = nothing)
591593
end
592594

595+
Msym = gensym(:M)
593596
M = calculate_massmatrix(sys)
594-
595-
_M = if sparse && !(u0 === nothing || M === I)
596-
SparseArrays.sparse(M)
597+
if sparse && !(u0 === nothing || M === I)
598+
_M = :($Msym = $(SparseArrays.sparse(M)))
597599
elseif u0 === nothing || M === I
598-
M
600+
_M = :($Msym = $M)
599601
else
600-
ArrayInterface.restructure(u0 .* u0', M)
602+
_M = :($Msym = $(ArrayInterface.restructure(u0 .* u0', M)))
601603
end
602604

603605
jp_expr = sparse ? :($similar($(get_jac(sys)[]), Float64)) : :nothing
604606
ex = quote
605-
$_f
606-
$_tgrad
607-
$_jac
608-
M = $_M
609-
ODEFunction{$iip}($fsym,
610-
sys = $sys,
611-
jac = $jacsym,
612-
tgrad = $tgradsym,
613-
mass_matrix = M,
614-
jac_prototype = $jp_expr,
615-
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
616-
observed = $observedfun_exp)
607+
let $_f, $_tgrad, $_jac, $_M
608+
ODEFunction{$iip, $specialize}($fsym,
609+
sys = $sys,
610+
jac = $jacsym,
611+
tgrad = $tgradsym,
612+
mass_matrix = $Msym,
613+
jac_prototype = $jp_expr,
614+
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
615+
observed = $observedfun_exp)
616+
end
617617
end
618618
!linenumbers ? Base.remove_linenums!(ex) : ex
619619
end
@@ -622,6 +622,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
622622
ODEFunctionExpr{true}(sys, args...; kwargs...)
623623
end
624624

625+
function ODEFunctionExpr{true}(sys::AbstractODESystem, args...; kwargs...)
626+
return ODEFunctionExpr{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
627+
end
628+
629+
function ODEFunctionExpr{false}(sys::AbstractODESystem, args...; kwargs...)
630+
return ODEFunctionExpr{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
631+
end
632+
625633
"""
626634
```julia
627635
DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),

0 commit comments

Comments
 (0)