From 5f8cda3a59aa3333c9f1144562ff7642f82bd4b8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 16:27:12 +0530 Subject: [PATCH 01/11] feat: add `assertions` field to `ODESystem` --- src/systems/abstractsystem.jl | 24 ++++++++++++++++++++++++ src/systems/diffeqs/odesystem.jl | 14 +++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index dc47457825..ad46e4529e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -983,6 +983,7 @@ for prop in [:eqs :gui_metadata :discrete_subsystems :parameter_dependencies + :assertions :solved_unknowns :split_idxs :parent @@ -1468,6 +1469,24 @@ end """ $(TYPEDSIGNATURES) +Get the assertions for a system `sys` and its subsystems. +""" +function assertions(sys::AbstractSystem) + has_assertions(sys) || return Dict{BasicSymbolic, String}() + + asserts = get_assertions(sys) + systems = get_systems(sys) + namespaced_asserts = mapreduce( + merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys + Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v + for (k, v) in assertions(subsys)) + end + return merge(asserts, namespaced_asserts) +end + +""" +$(TYPEDSIGNATURES) + Get the guesses for variables in the initialization system of the system `sys` and its subsystems. See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref). @@ -3036,6 +3055,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses)) end + if has_assertions(basesys) + kwargs = merge( + kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys)))) + end + return T(args...; kwargs...) end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c735b52c37..e32b43bd2a 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -137,6 +137,11 @@ struct ODESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ + Mapping of conditions which should be true throughout the solve to corresponding error + messages. These will be added to the equations when calling `debug_system`. + """ + assertions::Dict{BasicSymbolic, String} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -190,7 +195,7 @@ struct ODESystem <: AbstractODESystem jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching, initializesystem, initialization_eqs, schedule, connector_type, preface, cevents, - devents, parameter_dependencies, + devents, parameter_dependencies, assertions = Dict{BasicSymbolic, String}(), metadata = nothing, gui_metadata = nothing, is_dde = false, tstops = [], tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, @@ -210,7 +215,7 @@ struct ODESystem <: AbstractODESystem new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching, initializesystem, initialization_eqs, schedule, connector_type, preface, - cevents, devents, parameter_dependencies, metadata, + cevents, devents, parameter_dependencies, assertions, metadata, gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache, discrete_subsystems, solved_unknowns, split_idxs, parent) end @@ -235,6 +240,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; continuous_events = nothing, discrete_events = nothing, parameter_dependencies = Equation[], + assertions = Dict(), checks = true, metadata = nothing, gui_metadata = nothing, @@ -286,12 +292,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; if is_dde === nothing is_dde = _check_if_dde(deqs, iv′, systems) end + assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions) ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, nothing, initializesystem, initialization_eqs, schedule, connector_type, preface, cont_callbacks, - disc_callbacks, parameter_dependencies, + disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata, is_dde, tstops, checks = checks) end @@ -364,6 +371,7 @@ function flatten(sys::ODESystem, noeqs = false) name = nameof(sys), description = description(sys), initialization_eqs = initialization_equations(sys), + assertions = assertions(sys), is_dde = is_dde(sys), tstops = symbolic_tstops(sys), metadata = get_metadata(sys), From aa6f3402a5050f3d4c92eea45de8fbacdc98cc96 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 16:27:39 +0530 Subject: [PATCH 02/11] feat: add assertions support to `debug_system` --- src/debugging.jl | 18 ++++++++++++++++++ src/systems/abstractsystem.jl | 10 +++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/debugging.jl b/src/debugging.jl index a1a168d8dd..b30a1ac2ac 100644 --- a/src/debugging.jl +++ b/src/debugging.jl @@ -42,3 +42,21 @@ function debug_sub(ex, funcs; kw...) f in funcs ? logged_fun(f, args...; kw...) : maketerm(typeof(ex), f, args, metadata(ex)) end + +function _debug_assertion(expr::Bool, message::String, log::Bool) + expr && return 0.0 + log && @error message + return NaN +end + +@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool) + +const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false) + +function get_assertions_expr(assertions::Dict{BasicSymbolic, String}) + term = 0 + for (k, v) in assertions + term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE) + end + return term +end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ad46e4529e..80b27e4e84 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2312,7 +2312,15 @@ function debug_system( error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.") end if has_eqs(sys) - @set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...) + eqs = debug_sub.(equations(sys), Ref(functions); kw...) + expr = get_assertions_expr(assertions(sys)) + eqs[end] = eqs[end].lhs ~ eqs[end].rhs + expr + @set! sys.eqs = eqs + @set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE]) + @set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true)) + if iscomplete(sys) + sys = complete(sys; split = is_split(sys)) + end end if has_observed(sys) @set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...) From b5e6dd93f57c5e9bc51dfb1278c0411d8425724c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 17:13:22 +0530 Subject: [PATCH 03/11] feat: add `assertions` field to `SDESystem` --- src/systems/diffeqs/sdesystem.jl | 19 ++++++++++++++----- src/systems/systems.jl | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 212e20d743..8b7bc8b0d8 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -126,6 +126,11 @@ struct SDESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ + Mapping of conditions which should be true throughout the solve to corresponding error + messages. These will be added to the equations when calling `debug_system`. + """ + assertions::Dict{BasicSymbolic, String} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -159,7 +164,9 @@ struct SDESystem <: AbstractODESystem function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, - cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, + cevents, devents, parameter_dependencies, assertions = Dict{ + BasicSymbolic, Nothing}, + metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false, is_dde = false, isscheduled = false; @@ -185,9 +192,8 @@ struct SDESystem <: AbstractODESystem new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents, - devents, - parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, - is_dde, isscheduled) + devents, parameter_dependencies, assertions, metadata, gui_metadata, complete, + index_cache, parent, is_scalar_noise, is_dde, isscheduled) end end @@ -209,6 +215,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv continuous_events = nothing, discrete_events = nothing, parameter_dependencies = Equation[], + assertions = Dict{BasicSymbolic, String}(), metadata = nothing, gui_metadata = nothing, complete = false, @@ -261,11 +268,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv if is_dde === nothing is_dde = _check_if_dde(deqs, iv′, systems) end + assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions) SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, - cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata, + cont_callbacks, disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks) end @@ -378,6 +386,7 @@ function ODESystem(sys::SDESystem) newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys); parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys), continuous_events = continuous_events(sys), discrete_events = discrete_events(sys), + assertions = assertions(sys), name = nameof(sys), description = description(sys), metadata = get_metadata(sys)) @set newsys.parent = sys end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 9c8c272c5c..f8630f2d20 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -165,7 +165,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys), - parameter_dependencies = parameter_dependencies(sys), + parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys), guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end From 8e23967fd7c7e18a6d621252672f58753a07c06b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 17:13:31 +0530 Subject: [PATCH 04/11] test: test new assertions functionality --- test/debugging.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 40 insertions(+) create mode 100644 test/debugging.jl diff --git a/test/debugging.jl b/test/debugging.jl new file mode 100644 index 0000000000..f127df6cf8 --- /dev/null +++ b/test/debugging.jl @@ -0,0 +1,39 @@ +using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, SymbolicIndexingInterface +using ModelingToolkit: t_nounits as t, D_nounits as D, ASSERTION_LOG_VARIABLE + +@variables x(t) +@brownian a +@named inner_ode = ODESystem(D(x) ~ -sqrt(x), t; assertions = [(x > 0) => "ohno"]) +@named inner_sde = System([D(x) ~ -sqrt(x) + a], t; assertions = [(x > 0) => "ohno"]) +sys_ode = structural_simplify(inner_ode) +sys_sde = structural_simplify(inner_sde) + +@testset "`debug_system` adds assertions" begin + @testset "$(typeof(sys))" for (Problem, sys, alg) in [ + (ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())] + dsys = debug_system(sys; functions = []) + @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) + prob = Problem(dsys, [x => 1.0], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + prob.ps[ASSERTION_LOG_VARIABLE] = true + sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + end +end + +@testset "Hierarchical system" begin + @testset "$(typeof(inner))" for (ctor, Problem, inner, alg) in [ + (ODESystem, ODEProblem, inner_ode, Tsit5()), + (System, SDEProblem, inner_sde, ImplicitEM())] + @mtkbuild outer = ctor(Equation[], t; systems = [inner]) + dsys = debug_system(outer; functions = []) + @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) + prob = Problem(dsys, [inner.x => 1.0], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + prob.ps[ASSERTION_LOG_VARIABLE] = true + sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 52875fdae5..9537b1b44e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,6 +92,7 @@ end @safetestset "IfLifting Test" include("if_lifting.jl") @safetestset "Analysis Points Test" include("analysis_points.jl") @safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl") + @safetestset "Debugging Test" include("debugging.jl") end end From 5d2e8ede6355369b7f064864eb83852fcb5aec2c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 3 Feb 2025 04:37:21 -0800 Subject: [PATCH 05/11] Update src/systems/diffeqs/odesystem.jl Co-authored-by: Fredrik Bagge Carlson --- src/systems/diffeqs/odesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e32b43bd2a..96fd8534ae 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -137,7 +137,7 @@ struct ODESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ - Mapping of conditions which should be true throughout the solve to corresponding error + Mapping of conditions which should be true throughout the solution process to corresponding error messages. These will be added to the equations when calling `debug_system`. """ assertions::Dict{BasicSymbolic, String} From f9b8f52ee8254400849e435f65ed64eea09c29a4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 18:09:03 +0530 Subject: [PATCH 06/11] Update src/systems/diffeqs/sdesystem.jl Co-authored-by: Fredrik Bagge Carlson --- src/systems/diffeqs/sdesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 8b7bc8b0d8..92a8ddd710 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -126,7 +126,7 @@ struct SDESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ - Mapping of conditions which should be true throughout the solve to corresponding error + Mapping of conditions which should be true throughout the solution process to corresponding error messages. These will be added to the equations when calling `debug_system`. """ assertions::Dict{BasicSymbolic, String} From 41a43fdf506ce291241a9e0077a68570c662cf57 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Feb 2025 18:19:22 +0530 Subject: [PATCH 07/11] docs: add documentation for assertions functionality --- docs/src/basics/Debugging.md | 17 +++++++++++++++++ src/debugging.jl | 19 ++++++++++++++++++- src/systems/abstractsystem.jl | 8 ++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/docs/src/basics/Debugging.md b/docs/src/basics/Debugging.md index d5c51ec0c1..b2524832d9 100644 --- a/docs/src/basics/Debugging.md +++ b/docs/src/basics/Debugging.md @@ -35,6 +35,23 @@ dsol = solve(dprob, Tsit5()); Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function. We could have figured that out ourselves, but it is not always so obvious for more complex models. +Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality. + +```@example debug +@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"]) +``` + +The assertions must be an iterable of pairs, where the first element is the symbolic condition and +the second is the message to be logged when the condition fails. + +```@repl debug +dsys = debug_system(sys; functions = []); +dprob = ODEProblem(dsys, [], (0.0, 10.0)); +dsol = solve(dprob, Tsit5()); +``` + +Note the messages containing the failed assertion and corresponding message. + ```@docs debug_system ``` diff --git a/src/debugging.jl b/src/debugging.jl index b30a1ac2ac..266819d60b 100644 --- a/src/debugging.jl +++ b/src/debugging.jl @@ -43,6 +43,13 @@ function debug_sub(ex, funcs; kw...) maketerm(typeof(ex), f, args, metadata(ex)) end +""" + $(TYPEDSIGNATURES) + +A function which takes a condition `expr` and returns `NaN` if it is false, +and zero if it is true. In case the condition is false and `log == true`, +`message` will be logged as an `@error`. +""" function _debug_assertion(expr::Bool, message::String, log::Bool) expr && return 0.0 log && @error message @@ -51,9 +58,19 @@ end @register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool) +""" +Boolean parameter added to models returned from `debug_system` to control logging of +assertions. +""" const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false) -function get_assertions_expr(assertions::Dict{BasicSymbolic, String}) +""" + $(TYPEDSIGNATURES) + +Get a symbolic expression as per the requirement of `debug_system` for all the assertions +in `assertions`. `is_split` denotes whether the corresponding system is a split system. +""" +function get_assertions_expr(assertions::Dict{BasicSymbolic, String}, is_split::Bool) term = 0 for (k, v) in assertions term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 80b27e4e84..7857cc4c74 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2302,6 +2302,14 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input 1 => 1 sin(P(t)) => 0.0 ``` + +Additionally, all assertions in the system are validated in the equations. If any of +the conditions are false, the right hand side of at least one of the equations of +the system will evaluate to `NaN`. A new parameter is also added to the system which +controls whether the message associated with each assertion will be logged when the +assertion fails. This parameter defaults to `true` and can be toggled by +symbolic indexing with `ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example, +`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging. """ function debug_system( sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...) From b968b9244d2f9d4d99f7ed304f64a76eabcda171 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 10 Feb 2025 13:21:34 +0530 Subject: [PATCH 08/11] feat: add assertions to function regardless of `debug_system` add logging in `debug_system` --- src/debugging.jl | 35 +++++++++++++++++++----- src/systems/abstractsystem.jl | 19 ++++++------- src/systems/diffeqs/abstractodesystem.jl | 4 +++ test/debugging.jl | 17 ++++++++++-- 4 files changed, 54 insertions(+), 21 deletions(-) diff --git a/src/debugging.jl b/src/debugging.jl index 266819d60b..c16b47c2e3 100644 --- a/src/debugging.jl +++ b/src/debugging.jl @@ -43,6 +43,17 @@ function debug_sub(ex, funcs; kw...) maketerm(typeof(ex), f, args, metadata(ex)) end +""" + $(TYPEDSIGNATURES) + +A function which returns `NaN` if `condition` fails, and `0.0` otherwise. +""" +function _nan_condition(condition::Bool) + condition ? 0.0 : NaN +end + +@register_symbolic _nan_condition(condition::Bool) + """ $(TYPEDSIGNATURES) @@ -51,9 +62,10 @@ and zero if it is true. In case the condition is false and `log == true`, `message` will be logged as an `@error`. """ function _debug_assertion(expr::Bool, message::String, log::Bool) - expr && return 0.0 + value = _nan_condition(expr) + isnan(value) || return value log && @error message - return NaN + return value end @register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool) @@ -67,13 +79,22 @@ const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool """ $(TYPEDSIGNATURES) -Get a symbolic expression as per the requirement of `debug_system` for all the assertions -in `assertions`. `is_split` denotes whether the corresponding system is a split system. +Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN` +if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a +parameter in the system, it will control whether the message associated with each +assertion is logged when it fails. """ -function get_assertions_expr(assertions::Dict{BasicSymbolic, String}, is_split::Bool) +function get_assertions_expr(sys::AbstractSystem) + asserts = assertions(sys) term = 0 - for (k, v) in assertions - term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE) + if is_parameter(sys, ASSERTION_LOG_VARIABLE) + for (k, v) in asserts + term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE) + end + else + for (k, v) in asserts + term += _nan_condition(k) + end end return term end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 7857cc4c74..8213b8f241 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2303,12 +2303,11 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input sin(P(t)) => 0.0 ``` -Additionally, all assertions in the system are validated in the equations. If any of -the conditions are false, the right hand side of at least one of the equations of -the system will evaluate to `NaN`. A new parameter is also added to the system which -controls whether the message associated with each assertion will be logged when the -assertion fails. This parameter defaults to `true` and can be toggled by -symbolic indexing with `ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example, +Additionally, all assertions in the system are optionally logged when they fail. +A new parameter is also added to the system which controls whether the message associated +with each assertion will be logged when the assertion fails. This parameter defaults to +`true` and can be toggled by symbolic indexing with +`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example, `prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging. """ function debug_system( @@ -2321,18 +2320,16 @@ function debug_system( end if has_eqs(sys) eqs = debug_sub.(equations(sys), Ref(functions); kw...) - expr = get_assertions_expr(assertions(sys)) - eqs[end] = eqs[end].lhs ~ eqs[end].rhs + expr @set! sys.eqs = eqs @set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE]) @set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true)) - if iscomplete(sys) - sys = complete(sys; split = is_split(sys)) - end end if has_observed(sys) @set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...) end + if iscomplete(sys) + sys = complete(sys; split = is_split(sys)) + end return sys end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 8d0720a1d2..ea55dce388 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] : [eq.rhs for eq in eqs] + if !isempty(assertions(sys)) + rhss[end] += unwrap(get_assertions_expr(sys)) + end + # TODO: add an optional check on the ordering of observed equations u = dvs p = reorder_parameters(sys, ps) diff --git a/test/debugging.jl b/test/debugging.jl index f127df6cf8..aad36eb995 100644 --- a/test/debugging.jl +++ b/test/debugging.jl @@ -8,12 +8,23 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, ASSERTION_LOG_VARIABLE sys_ode = structural_simplify(inner_ode) sys_sde = structural_simplify(inner_sde) -@testset "`debug_system` adds assertions" begin +@testset "assertions are present in generated `f`" begin + @testset "$(typeof(sys))" for (Problem, sys, alg) in [ + (ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())] + @test !is_parameter(sys, ASSERTION_LOG_VARIABLE) + prob = Problem(sys, [x => 0.1], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + @test isnan(prob.f.f([0.0], prob.p, sol.t[end])[1]) + end +end + +@testset "`debug_system` adds logging" begin @testset "$(typeof(sys))" for (Problem, sys, alg) in [ (ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())] dsys = debug_system(sys; functions = []) @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) - prob = Problem(dsys, [x => 1.0], (0.0, 5.0)) + prob = Problem(dsys, [x => 0.1], (0.0, 5.0)) sol = solve(prob, alg) @test !SciMLBase.successful_retcode(sol) prob.ps[ASSERTION_LOG_VARIABLE] = true @@ -29,7 +40,7 @@ end @mtkbuild outer = ctor(Equation[], t; systems = [inner]) dsys = debug_system(outer; functions = []) @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) - prob = Problem(dsys, [inner.x => 1.0], (0.0, 5.0)) + prob = Problem(dsys, [inner.x => 0.1], (0.0, 5.0)) sol = solve(prob, alg) @test !SciMLBase.successful_retcode(sol) prob.ps[ASSERTION_LOG_VARIABLE] = true From 3ca004e4b11e2513121c4e68f3d56ad6918effa2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 10 Feb 2025 15:11:59 +0530 Subject: [PATCH 09/11] docs: update assertions documentation --- docs/src/basics/Debugging.md | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/src/basics/Debugging.md b/docs/src/basics/Debugging.md index b2524832d9..b432810bd0 100644 --- a/docs/src/basics/Debugging.md +++ b/docs/src/basics/Debugging.md @@ -42,7 +42,17 @@ Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the a ``` The assertions must be an iterable of pairs, where the first element is the symbolic condition and -the second is the message to be logged when the condition fails. +the second is a message to be logged when the condition fails. All assertions are added to the +generated code and will cause the solver to reject steps that fail the assertions. For systems such +as the above where the assertion is guaranteed to eventually fail, the solver will likely exit +with a `dtmin` failure.. + +```@example debug +prob = ODEProblem(sys, [], (0.0, 10.0)) +sol = solve(prob, Tsit5()) +``` + +We can use `debug_system` to log the failing assertions in each call to the RHS function. ```@repl debug dsys = debug_system(sys; functions = []); @@ -50,7 +60,13 @@ dprob = ODEProblem(dsys, [], (0.0, 10.0)); dsol = solve(dprob, Tsit5()); ``` -Note the messages containing the failed assertion and corresponding message. +Note the logs containing the failed assertion and corresponding message. To temporarily disable +logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`. + +```@repl debug +dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false; +solve(drob, Tsit5()); +``` ```@docs debug_system From 8fe1e957d5a1239bab24b219bc0cdab842659334 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 10 Feb 2025 16:42:16 +0530 Subject: [PATCH 10/11] docs: fix broken example blocks --- docs/src/basics/Variable_metadata.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/basics/Variable_metadata.md b/docs/src/basics/Variable_metadata.md index 44dfb30327..b2cc472f2f 100644 --- a/docs/src/basics/Variable_metadata.md +++ b/docs/src/basics/Variable_metadata.md @@ -183,7 +183,7 @@ A variable can be marked `irreducible` to prevent it from being moved to an it can be accessed in [callbacks](@ref events) ```@example metadata -@variable important_value [irreducible = true] +@variables important_value [irreducible = true] isirreducible(important_value) ``` @@ -192,7 +192,7 @@ isirreducible(important_value) When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata. ```@example metadata -@variable important_dof [state_priority = 10] unimportant_dof [state_priority = -2] +@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2] state_priority(important_dof) ``` @@ -201,7 +201,7 @@ state_priority(important_dof) Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions. ```@example metadata -@variable speed [unit = u"m/s"] +@variables speed [unit = u"m/s"] hasunit(speed) ``` From 7cd399f82b0be3552789c441b8268769ccc4bb68 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 10 Feb 2025 18:13:54 +0530 Subject: [PATCH 11/11] feat: export `state_priority` --- src/ModelingToolkit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index ed99fb2772..09b59c4ed6 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -249,7 +249,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc istunable, getdist, hasdist, tunable_parameters, isirreducible, getdescription, hasdescription, hasunit, getunit, hasconnect, getconnect, - hasmisc, getmisc + hasmisc, getmisc, state_priority export ode_order_lowering, dae_order_lowering, liouville_transform export PDESystem export Differential, expand_derivatives, @derivatives