Skip to content

Commit 9b492cd

Browse files
committed
Merge remote-tracking branch 'vyudu/BVP-with-constraints' into BVP-with-constraints
2 parents 6740b8c + b10a4a6 commit 9b492cd

20 files changed

+506
-63
lines changed

docs/src/basics/Debugging.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ dsol = solve(dprob, Tsit5());
3535
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `` function.
3636
We could have figured that out ourselves, but it is not always so obvious for more complex models.
3737

38+
Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality.
39+
40+
```@example debug
41+
@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"])
42+
```
43+
44+
The assertions must be an iterable of pairs, where the first element is the symbolic condition and
45+
the second is a message to be logged when the condition fails. All assertions are added to the
46+
generated code and will cause the solver to reject steps that fail the assertions. For systems such
47+
as the above where the assertion is guaranteed to eventually fail, the solver will likely exit
48+
with a `dtmin` failure..
49+
50+
```@example debug
51+
prob = ODEProblem(sys, [], (0.0, 10.0))
52+
sol = solve(prob, Tsit5())
53+
```
54+
55+
We can use `debug_system` to log the failing assertions in each call to the RHS function.
56+
57+
```@repl debug
58+
dsys = debug_system(sys; functions = []);
59+
dprob = ODEProblem(dsys, [], (0.0, 10.0));
60+
dsol = solve(dprob, Tsit5());
61+
```
62+
63+
Note the logs containing the failed assertion and corresponding message. To temporarily disable
64+
logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`.
65+
66+
```@repl debug
67+
dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false;
68+
solve(drob, Tsit5());
69+
```
70+
3871
```@docs
3972
debug_system
4073
```

docs/src/basics/Variable_metadata.md

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
5454
5555
@variables i(t) [connect = Flow]
5656
@variables k(t) [connect = Stream]
57+
hasconnect(i)
58+
```
59+
60+
```@example connect
61+
getconnect(k)
5762
```
5863

5964
## Input or output
@@ -177,8 +182,45 @@ A variable can be marked `irreducible` to prevent it from being moved to an
177182
`observed` state. This forces the variable to be computed during solving so that
178183
it can be accessed in [callbacks](@ref events)
179184

180-
```julia
181-
@variable important_value [irreducible = true]
185+
```@example metadata
186+
@variables important_value [irreducible = true]
187+
isirreducible(important_value)
188+
```
189+
190+
## State Priority
191+
192+
When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata.
193+
194+
```@example metadata
195+
@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2]
196+
state_priority(important_dof)
197+
```
198+
199+
## Units
200+
201+
Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions.
202+
203+
```@example metadata
204+
@variables speed [unit = u"m/s"]
205+
hasunit(speed)
206+
```
207+
208+
```@example metadata
209+
getunit(speed)
210+
```
211+
212+
## Miscellaneous metadata
213+
214+
User-defined metadata can be added using the `misc` metadata. This can be queried
215+
using the `hasmisc` and `getmisc` functions.
216+
217+
```@example metadata
218+
@variables u [misc = :conserved_parameter] y [misc = [2, 4, 6]]
219+
hasmisc(u)
220+
```
221+
222+
```@example metadata
223+
getmisc(y)
182224
```
183225

184226
## Additional functions

docs/src/tutorials/initialization.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ principles of initialization of DAE systems. Take a DAE written in semi-explicit
1414
form:
1515

1616
```math
17-
x' = f(x,y,t)\\
18-
0 = g(x,y,t)
17+
\begin{aligned}
18+
x^\prime &= f(x,y,t) \\
19+
0 &= g(x,y,t)
20+
\end{aligned}
1921
```
2022

2123
where ``x`` are the differential variables and ``y`` are the algebraic variables.

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ export initial_state, transition, activeState, entry, ticksInState, timeInState
247247
export @component, @mtkmodel, @mtkbuild
248248
export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbance,
249249
istunable, getdist, hasdist,
250-
tunable_parameters, isirreducible, getdescription, hasdescription
250+
tunable_parameters, isirreducible, getdescription, hasdescription,
251+
hasunit, getunit, hasconnect, getconnect,
252+
hasmisc, getmisc, state_priority
251253
export ode_order_lowering, dae_order_lowering, liouville_transform
252254
export PDESystem
253255
export Differential, expand_derivatives, @derivatives

src/debugging.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,59 @@ function debug_sub(ex, funcs; kw...)
4242
f in funcs ? logged_fun(f, args...; kw...) :
4343
maketerm(typeof(ex), f, args, metadata(ex))
4444
end
45+
46+
"""
47+
$(TYPEDSIGNATURES)
48+
49+
A function which returns `NaN` if `condition` fails, and `0.0` otherwise.
50+
"""
51+
function _nan_condition(condition::Bool)
52+
condition ? 0.0 : NaN
53+
end
54+
55+
@register_symbolic _nan_condition(condition::Bool)
56+
57+
"""
58+
$(TYPEDSIGNATURES)
59+
60+
A function which takes a condition `expr` and returns `NaN` if it is false,
61+
and zero if it is true. In case the condition is false and `log == true`,
62+
`message` will be logged as an `@error`.
63+
"""
64+
function _debug_assertion(expr::Bool, message::String, log::Bool)
65+
value = _nan_condition(expr)
66+
isnan(value) || return value
67+
log && @error message
68+
return value
69+
end
70+
71+
@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool)
72+
73+
"""
74+
Boolean parameter added to models returned from `debug_system` to control logging of
75+
assertions.
76+
"""
77+
const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false)
78+
79+
"""
80+
$(TYPEDSIGNATURES)
81+
82+
Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN`
83+
if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a
84+
parameter in the system, it will control whether the message associated with each
85+
assertion is logged when it fails.
86+
"""
87+
function get_assertions_expr(sys::AbstractSystem)
88+
asserts = assertions(sys)
89+
term = 0
90+
if is_parameter(sys, ASSERTION_LOG_VARIABLE)
91+
for (k, v) in asserts
92+
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
93+
end
94+
else
95+
for (k, v) in asserts
96+
term += _nan_condition(k)
97+
end
98+
end
99+
return term
100+
end

src/systems/abstractsystem.jl

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,7 @@ for prop in [:eqs
984984
:gui_metadata
985985
:discrete_subsystems
986986
:parameter_dependencies
987+
:assertions
987988
:solved_unknowns
988989
:split_idxs
989990
:parent
@@ -1469,6 +1470,24 @@ end
14691470
"""
14701471
$(TYPEDSIGNATURES)
14711472
1473+
Get the assertions for a system `sys` and its subsystems.
1474+
"""
1475+
function assertions(sys::AbstractSystem)
1476+
has_assertions(sys) || return Dict{BasicSymbolic, String}()
1477+
1478+
asserts = get_assertions(sys)
1479+
systems = get_systems(sys)
1480+
namespaced_asserts = mapreduce(
1481+
merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys
1482+
Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v
1483+
for (k, v) in assertions(subsys))
1484+
end
1485+
return merge(asserts, namespaced_asserts)
1486+
end
1487+
1488+
"""
1489+
$(TYPEDSIGNATURES)
1490+
14721491
Get the guesses for variables in the initialization system of the system `sys` and its subsystems.
14731492
14741493
See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref).
@@ -2249,15 +2268,20 @@ macro mtkbuild(exprs...)
22492268
expr = exprs[1]
22502269
named_expr = ModelingToolkit.named_expr(expr)
22512270
name = named_expr.args[1]
2252-
kwargs = if length(exprs) > 1
2253-
NamedTuple{Tuple(ex.args[1] for ex in Base.tail(exprs))}(Tuple(ex.args[2]
2254-
for ex in Base.tail(exprs)))
2271+
kwargs = Base.tail(exprs)
2272+
kwargs = map(kwargs) do ex
2273+
@assert ex.head == :(=)
2274+
Expr(:kw, ex.args[1], ex.args[2])
2275+
end
2276+
if isempty(kwargs)
2277+
kwargs = ()
22552278
else
2256-
(;)
2279+
kwargs = (Expr(:parameters, kwargs...),)
22572280
end
2281+
call_expr = Expr(:call, structural_simplify, kwargs..., name)
22582282
esc(quote
22592283
$named_expr
2260-
$name = $structural_simplify($name; $(kwargs)...)
2284+
$name = $call_expr
22612285
end)
22622286
end
22632287

@@ -2279,6 +2303,13 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
22792303
1 => 1
22802304
sin(P(t)) => 0.0
22812305
```
2306+
2307+
Additionally, all assertions in the system are optionally logged when they fail.
2308+
A new parameter is also added to the system which controls whether the message associated
2309+
with each assertion will be logged when the assertion fails. This parameter defaults to
2310+
`true` and can be toggled by symbolic indexing with
2311+
`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
2312+
`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging.
22822313
"""
22832314
function debug_system(
22842315
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
@@ -2289,11 +2320,17 @@ function debug_system(
22892320
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
22902321
end
22912322
if has_eqs(sys)
2292-
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2323+
eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2324+
@set! sys.eqs = eqs
2325+
@set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE])
2326+
@set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true))
22932327
end
22942328
if has_observed(sys)
22952329
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
22962330
end
2331+
if iscomplete(sys)
2332+
sys = complete(sys; split = is_split(sys))
2333+
end
22972334
return sys
22982335
end
22992336

@@ -3032,6 +3069,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30323069
kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses))
30333070
end
30343071

3072+
if has_assertions(basesys)
3073+
kwargs = merge(
3074+
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
3075+
end
3076+
30353077
return T(args...; kwargs...)
30363078
end
30373079

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
168168
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
169169
[eq.rhs for eq in eqs]
170170

171+
if !isempty(assertions(sys))
172+
rhss[end] += unwrap(get_assertions_expr(sys))
173+
end
174+
171175
# TODO: add an optional check on the ordering of observed equations
172176
u = dvs
173177
p = reorder_parameters(sys, ps)
@@ -542,7 +546,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
542546
The arguments `dvs` and `ps` are used to set the order of the dependent
543547
variable and parameter vectors, respectively.
544548
"""
545-
struct ODEFunctionExpr{iip} end
549+
struct ODEFunctionExpr{iip, specialize} end
546550

547551
struct ODEFunctionClosure{O, I} <: Function
548552
f_oop::O
@@ -551,7 +555,7 @@ end
551555
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
552556
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
553557

554-
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
558+
function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
555559
ps = parameters(sys), u0 = nothing;
556560
version = nothing, tgrad = false,
557561
jac = false, p = nothing,
@@ -560,14 +564,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
560564
steady_state = false,
561565
sparsity = false,
562566
observedfun_exp = nothing,
563-
kwargs...) where {iip}
567+
kwargs...) where {iip, specialize}
564568
if !iscomplete(sys)
565569
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`")
566570
end
567571
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
568572

569-
dict = Dict()
570-
571573
fsym = gensym(:f)
572574
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
573575
tgradsym = gensym(:tgrad)
@@ -590,30 +592,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
590592
_jac = :($jacsym = nothing)
591593
end
592594

595+
Msym = gensym(:M)
593596
M = calculate_massmatrix(sys)
594-
595-
_M = if sparse && !(u0 === nothing || M === I)
596-
SparseArrays.sparse(M)
597+
if sparse && !(u0 === nothing || M === I)
598+
_M = :($Msym = $(SparseArrays.sparse(M)))
597599
elseif u0 === nothing || M === I
598-
M
600+
_M = :($Msym = $M)
599601
else
600-
ArrayInterface.restructure(u0 .* u0', M)
602+
_M = :($Msym = $(ArrayInterface.restructure(u0 .* u0', M)))
601603
end
602604

603605
jp_expr = sparse ? :($similar($(get_jac(sys)[]), Float64)) : :nothing
604606
ex = quote
605-
$_f
606-
$_tgrad
607-
$_jac
608-
M = $_M
609-
ODEFunction{$iip}($fsym,
610-
sys = $sys,
611-
jac = $jacsym,
612-
tgrad = $tgradsym,
613-
mass_matrix = M,
614-
jac_prototype = $jp_expr,
615-
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
616-
observed = $observedfun_exp)
607+
let $_f, $_tgrad, $_jac, $_M
608+
ODEFunction{$iip, $specialize}($fsym,
609+
sys = $sys,
610+
jac = $jacsym,
611+
tgrad = $tgradsym,
612+
mass_matrix = $Msym,
613+
jac_prototype = $jp_expr,
614+
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
615+
observed = $observedfun_exp)
616+
end
617617
end
618618
!linenumbers ? Base.remove_linenums!(ex) : ex
619619
end
@@ -622,6 +622,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
622622
ODEFunctionExpr{true}(sys, args...; kwargs...)
623623
end
624624

625+
function ODEFunctionExpr{true}(sys::AbstractODESystem, args...; kwargs...)
626+
return ODEFunctionExpr{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
627+
end
628+
629+
function ODEFunctionExpr{false}(sys::AbstractODESystem, args...; kwargs...)
630+
return ODEFunctionExpr{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
631+
end
632+
625633
"""
626634
```julia
627635
DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),

0 commit comments

Comments
 (0)