Skip to content

Commit b968b92

Browse files
feat: add assertions to function regardless of debug_system
add logging in `debug_system`
1 parent 41a43fd commit b968b92

File tree

4 files changed

+54
-21
lines changed

4 files changed

+54
-21
lines changed

src/debugging.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ function debug_sub(ex, funcs; kw...)
4343
maketerm(typeof(ex), f, args, metadata(ex))
4444
end
4545

46+
"""
47+
$(TYPEDSIGNATURES)
48+
49+
A function which returns `NaN` if `condition` fails, and `0.0` otherwise.
50+
"""
51+
function _nan_condition(condition::Bool)
52+
condition ? 0.0 : NaN
53+
end
54+
55+
@register_symbolic _nan_condition(condition::Bool)
56+
4657
"""
4758
$(TYPEDSIGNATURES)
4859
@@ -51,9 +62,10 @@ and zero if it is true. In case the condition is false and `log == true`,
5162
`message` will be logged as an `@error`.
5263
"""
5364
function _debug_assertion(expr::Bool, message::String, log::Bool)
54-
expr && return 0.0
65+
value = _nan_condition(expr)
66+
isnan(value) || return value
5567
log && @error message
56-
return NaN
68+
return value
5769
end
5870

5971
@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool)
@@ -67,13 +79,22 @@ const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool
6779
"""
6880
$(TYPEDSIGNATURES)
6981
70-
Get a symbolic expression as per the requirement of `debug_system` for all the assertions
71-
in `assertions`. `is_split` denotes whether the corresponding system is a split system.
82+
Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN`
83+
if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a
84+
parameter in the system, it will control whether the message associated with each
85+
assertion is logged when it fails.
7286
"""
73-
function get_assertions_expr(assertions::Dict{BasicSymbolic, String}, is_split::Bool)
87+
function get_assertions_expr(sys::AbstractSystem)
88+
asserts = assertions(sys)
7489
term = 0
75-
for (k, v) in assertions
76-
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
90+
if is_parameter(sys, ASSERTION_LOG_VARIABLE)
91+
for (k, v) in asserts
92+
term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE)
93+
end
94+
else
95+
for (k, v) in asserts
96+
term += _nan_condition(k)
97+
end
7798
end
7899
return term
79100
end

src/systems/abstractsystem.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,12 +2303,11 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
23032303
sin(P(t)) => 0.0
23042304
```
23052305
2306-
Additionally, all assertions in the system are validated in the equations. If any of
2307-
the conditions are false, the right hand side of at least one of the equations of
2308-
the system will evaluate to `NaN`. A new parameter is also added to the system which
2309-
controls whether the message associated with each assertion will be logged when the
2310-
assertion fails. This parameter defaults to `true` and can be toggled by
2311-
symbolic indexing with `ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
2306+
Additionally, all assertions in the system are optionally logged when they fail.
2307+
A new parameter is also added to the system which controls whether the message associated
2308+
with each assertion will be logged when the assertion fails. This parameter defaults to
2309+
`true` and can be toggled by symbolic indexing with
2310+
`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example,
23122311
`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging.
23132312
"""
23142313
function debug_system(
@@ -2321,18 +2320,16 @@ function debug_system(
23212320
end
23222321
if has_eqs(sys)
23232322
eqs = debug_sub.(equations(sys), Ref(functions); kw...)
2324-
expr = get_assertions_expr(assertions(sys))
2325-
eqs[end] = eqs[end].lhs ~ eqs[end].rhs + expr
23262323
@set! sys.eqs = eqs
23272324
@set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE])
23282325
@set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true))
2329-
if iscomplete(sys)
2330-
sys = complete(sys; split = is_split(sys))
2331-
end
23322326
end
23332327
if has_observed(sys)
23342328
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
23352329
end
2330+
if iscomplete(sys)
2331+
sys = complete(sys; split = is_split(sys))
2332+
end
23362333
return sys
23372334
end
23382335

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
168168
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
169169
[eq.rhs for eq in eqs]
170170

171+
if !isempty(assertions(sys))
172+
rhss[end] += unwrap(get_assertions_expr(sys))
173+
end
174+
171175
# TODO: add an optional check on the ordering of observed equations
172176
u = dvs
173177
p = reorder_parameters(sys, ps)

test/debugging.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,23 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, ASSERTION_LOG_VARIABLE
88
sys_ode = structural_simplify(inner_ode)
99
sys_sde = structural_simplify(inner_sde)
1010

11-
@testset "`debug_system` adds assertions" begin
11+
@testset "assertions are present in generated `f`" begin
12+
@testset "$(typeof(sys))" for (Problem, sys, alg) in [
13+
(ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())]
14+
@test !is_parameter(sys, ASSERTION_LOG_VARIABLE)
15+
prob = Problem(sys, [x => 0.1], (0.0, 5.0))
16+
sol = solve(prob, alg)
17+
@test !SciMLBase.successful_retcode(sol)
18+
@test isnan(prob.f.f([0.0], prob.p, sol.t[end])[1])
19+
end
20+
end
21+
22+
@testset "`debug_system` adds logging" begin
1223
@testset "$(typeof(sys))" for (Problem, sys, alg) in [
1324
(ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())]
1425
dsys = debug_system(sys; functions = [])
1526
@test is_parameter(dsys, ASSERTION_LOG_VARIABLE)
16-
prob = Problem(dsys, [x => 1.0], (0.0, 5.0))
27+
prob = Problem(dsys, [x => 0.1], (0.0, 5.0))
1728
sol = solve(prob, alg)
1829
@test !SciMLBase.successful_retcode(sol)
1930
prob.ps[ASSERTION_LOG_VARIABLE] = true
@@ -29,7 +40,7 @@ end
2940
@mtkbuild outer = ctor(Equation[], t; systems = [inner])
3041
dsys = debug_system(outer; functions = [])
3142
@test is_parameter(dsys, ASSERTION_LOG_VARIABLE)
32-
prob = Problem(dsys, [inner.x => 1.0], (0.0, 5.0))
43+
prob = Problem(dsys, [inner.x => 0.1], (0.0, 5.0))
3344
sol = solve(prob, alg)
3445
@test !SciMLBase.successful_retcode(sol)
3546
prob.ps[ASSERTION_LOG_VARIABLE] = true

0 commit comments

Comments
 (0)