diff --git a/docs/src/basics/Debugging.md b/docs/src/basics/Debugging.md index d5c51ec0c1..b432810bd0 100644 --- a/docs/src/basics/Debugging.md +++ b/docs/src/basics/Debugging.md @@ -35,6 +35,39 @@ dsol = solve(dprob, Tsit5()); Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function. We could have figured that out ourselves, but it is not always so obvious for more complex models. +Suppose we also want to validate that `u1 + u2 >= 2.0`. We can do this via the assertions functionality. + +```@example debug +@mtkbuild sys = ODESystem(eqs, t; defaults, assertions = [(u1 + u2 >= 2.0) => "Oh no!"]) +``` + +The assertions must be an iterable of pairs, where the first element is the symbolic condition and +the second is a message to be logged when the condition fails. All assertions are added to the +generated code and will cause the solver to reject steps that fail the assertions. For systems such +as the above where the assertion is guaranteed to eventually fail, the solver will likely exit +with a `dtmin` failure.. + +```@example debug +prob = ODEProblem(sys, [], (0.0, 10.0)) +sol = solve(prob, Tsit5()) +``` + +We can use `debug_system` to log the failing assertions in each call to the RHS function. + +```@repl debug +dsys = debug_system(sys; functions = []); +dprob = ODEProblem(dsys, [], (0.0, 10.0)); +dsol = solve(dprob, Tsit5()); +``` + +Note the logs containing the failed assertion and corresponding message. To temporarily disable +logging in a system returned from `debug_system`, use `ModelingToolkit.ASSERTION_LOG_VARIABLE`. + +```@repl debug +dprob[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false; +solve(drob, Tsit5()); +``` + ```@docs debug_system ``` diff --git a/docs/src/basics/Variable_metadata.md b/docs/src/basics/Variable_metadata.md index 44dfb30327..b2cc472f2f 100644 --- a/docs/src/basics/Variable_metadata.md +++ b/docs/src/basics/Variable_metadata.md @@ -183,7 +183,7 @@ A variable can be marked `irreducible` to prevent it from being moved to an it can be accessed in [callbacks](@ref events) ```@example metadata -@variable important_value [irreducible = true] +@variables important_value [irreducible = true] isirreducible(important_value) ``` @@ -192,7 +192,7 @@ isirreducible(important_value) When a model is structurally simplified, the algorithm will try to ensure that the variables with higher state priority become states of the system. A variable's state priority is a number set using the `state_priority` metadata. ```@example metadata -@variable important_dof [state_priority = 10] unimportant_dof [state_priority = -2] +@variables important_dof [state_priority = 10] unimportant_dof [state_priority = -2] state_priority(important_dof) ``` @@ -201,7 +201,7 @@ state_priority(important_dof) Units for variables can be designated using symbolic metadata. For more information, please see the [model validation and units](@ref units) section of the docs. Note that `getunit` is not equivalent to `get_unit` - the former is a metadata getter for individual variables (and is provided so the same interface function for `unit` exists like other metadata), while the latter is used to handle more general symbolic expressions. ```@example metadata -@variable speed [unit = u"m/s"] +@variables speed [unit = u"m/s"] hasunit(speed) ``` diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index ed99fb2772..09b59c4ed6 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -249,7 +249,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc istunable, getdist, hasdist, tunable_parameters, isirreducible, getdescription, hasdescription, hasunit, getunit, hasconnect, getconnect, - hasmisc, getmisc + hasmisc, getmisc, state_priority export ode_order_lowering, dae_order_lowering, liouville_transform export PDESystem export Differential, expand_derivatives, @derivatives diff --git a/src/debugging.jl b/src/debugging.jl index a1a168d8dd..c16b47c2e3 100644 --- a/src/debugging.jl +++ b/src/debugging.jl @@ -42,3 +42,59 @@ function debug_sub(ex, funcs; kw...) f in funcs ? logged_fun(f, args...; kw...) : maketerm(typeof(ex), f, args, metadata(ex)) end + +""" + $(TYPEDSIGNATURES) + +A function which returns `NaN` if `condition` fails, and `0.0` otherwise. +""" +function _nan_condition(condition::Bool) + condition ? 0.0 : NaN +end + +@register_symbolic _nan_condition(condition::Bool) + +""" + $(TYPEDSIGNATURES) + +A function which takes a condition `expr` and returns `NaN` if it is false, +and zero if it is true. In case the condition is false and `log == true`, +`message` will be logged as an `@error`. +""" +function _debug_assertion(expr::Bool, message::String, log::Bool) + value = _nan_condition(expr) + isnan(value) || return value + log && @error message + return value +end + +@register_symbolic _debug_assertion(expr::Bool, message::String, log::Bool) + +""" +Boolean parameter added to models returned from `debug_system` to control logging of +assertions. +""" +const ASSERTION_LOG_VARIABLE = only(@parameters __log_assertions_ₘₜₖ::Bool = false) + +""" + $(TYPEDSIGNATURES) + +Get a symbolic expression for all the assertions in `sys`. The expression returns `NaN` +if any of the assertions fail, and `0.0` otherwise. If `ASSERTION_LOG_VARIABLE` is a +parameter in the system, it will control whether the message associated with each +assertion is logged when it fails. +""" +function get_assertions_expr(sys::AbstractSystem) + asserts = assertions(sys) + term = 0 + if is_parameter(sys, ASSERTION_LOG_VARIABLE) + for (k, v) in asserts + term += _debug_assertion(k, "Assertion $k failed:\n$v", ASSERTION_LOG_VARIABLE) + end + else + for (k, v) in asserts + term += _nan_condition(k) + end + end + return term +end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index dc47457825..8213b8f241 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -983,6 +983,7 @@ for prop in [:eqs :gui_metadata :discrete_subsystems :parameter_dependencies + :assertions :solved_unknowns :split_idxs :parent @@ -1468,6 +1469,24 @@ end """ $(TYPEDSIGNATURES) +Get the assertions for a system `sys` and its subsystems. +""" +function assertions(sys::AbstractSystem) + has_assertions(sys) || return Dict{BasicSymbolic, String}() + + asserts = get_assertions(sys) + systems = get_systems(sys) + namespaced_asserts = mapreduce( + merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys + Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v + for (k, v) in assertions(subsys)) + end + return merge(asserts, namespaced_asserts) +end + +""" +$(TYPEDSIGNATURES) + Get the guesses for variables in the initialization system of the system `sys` and its subsystems. See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref). @@ -2283,6 +2302,13 @@ ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input 1 => 1 sin(P(t)) => 0.0 ``` + +Additionally, all assertions in the system are optionally logged when they fail. +A new parameter is also added to the system which controls whether the message associated +with each assertion will be logged when the assertion fails. This parameter defaults to +`true` and can be toggled by symbolic indexing with +`ModelingToolkit.ASSERTION_LOG_VARIABLE`. For example, +`prob.ps[ModelingToolkit.ASSERTION_LOG_VARIABLE] = false` will disable logging. """ function debug_system( sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...) @@ -2293,11 +2319,17 @@ function debug_system( error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.") end if has_eqs(sys) - @set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...) + eqs = debug_sub.(equations(sys), Ref(functions); kw...) + @set! sys.eqs = eqs + @set! sys.ps = unique!([get_ps(sys); ASSERTION_LOG_VARIABLE]) + @set! sys.defaults = merge(get_defaults(sys), Dict(ASSERTION_LOG_VARIABLE => true)) end if has_observed(sys) @set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...) end + if iscomplete(sys) + sys = complete(sys; split = is_split(sys)) + end return sys end @@ -3036,6 +3068,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses)) end + if has_assertions(basesys) + kwargs = merge( + kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys)))) + end + return T(args...; kwargs...) end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 8d0720a1d2..ea55dce388 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -168,6 +168,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] : [eq.rhs for eq in eqs] + if !isempty(assertions(sys)) + rhss[end] += unwrap(get_assertions_expr(sys)) + end + # TODO: add an optional check on the ordering of observed equations u = dvs p = reorder_parameters(sys, ps) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c735b52c37..96fd8534ae 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -137,6 +137,11 @@ struct ODESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ + Mapping of conditions which should be true throughout the solution process to corresponding error + messages. These will be added to the equations when calling `debug_system`. + """ + assertions::Dict{BasicSymbolic, String} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -190,7 +195,7 @@ struct ODESystem <: AbstractODESystem jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching, initializesystem, initialization_eqs, schedule, connector_type, preface, cevents, - devents, parameter_dependencies, + devents, parameter_dependencies, assertions = Dict{BasicSymbolic, String}(), metadata = nothing, gui_metadata = nothing, is_dde = false, tstops = [], tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, @@ -210,7 +215,7 @@ struct ODESystem <: AbstractODESystem new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching, initializesystem, initialization_eqs, schedule, connector_type, preface, - cevents, devents, parameter_dependencies, metadata, + cevents, devents, parameter_dependencies, assertions, metadata, gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache, discrete_subsystems, solved_unknowns, split_idxs, parent) end @@ -235,6 +240,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; continuous_events = nothing, discrete_events = nothing, parameter_dependencies = Equation[], + assertions = Dict(), checks = true, metadata = nothing, gui_metadata = nothing, @@ -286,12 +292,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; if is_dde === nothing is_dde = _check_if_dde(deqs, iv′, systems) end + assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions) ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, nothing, initializesystem, initialization_eqs, schedule, connector_type, preface, cont_callbacks, - disc_callbacks, parameter_dependencies, + disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata, is_dde, tstops, checks = checks) end @@ -364,6 +371,7 @@ function flatten(sys::ODESystem, noeqs = false) name = nameof(sys), description = description(sys), initialization_eqs = initialization_equations(sys), + assertions = assertions(sys), is_dde = is_dde(sys), tstops = symbolic_tstops(sys), metadata = get_metadata(sys), diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 212e20d743..92a8ddd710 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -126,6 +126,11 @@ struct SDESystem <: AbstractODESystem """ parameter_dependencies::Vector{Equation} """ + Mapping of conditions which should be true throughout the solution process to corresponding error + messages. These will be added to the equations when calling `debug_system`. + """ + assertions::Dict{BasicSymbolic, String} + """ Metadata for the system, to be used by downstream packages. """ metadata::Any @@ -159,7 +164,9 @@ struct SDESystem <: AbstractODESystem function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, - cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, + cevents, devents, parameter_dependencies, assertions = Dict{ + BasicSymbolic, Nothing}, + metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false, is_dde = false, isscheduled = false; @@ -185,9 +192,8 @@ struct SDESystem <: AbstractODESystem new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents, - devents, - parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, - is_dde, isscheduled) + devents, parameter_dependencies, assertions, metadata, gui_metadata, complete, + index_cache, parent, is_scalar_noise, is_dde, isscheduled) end end @@ -209,6 +215,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv continuous_events = nothing, discrete_events = nothing, parameter_dependencies = Equation[], + assertions = Dict{BasicSymbolic, String}(), metadata = nothing, gui_metadata = nothing, complete = false, @@ -261,11 +268,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv if is_dde === nothing is_dde = _check_if_dde(deqs, iv′, systems) end + assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions) SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, - cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata, + cont_callbacks, disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks) end @@ -378,6 +386,7 @@ function ODESystem(sys::SDESystem) newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys); parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys), continuous_events = continuous_events(sys), discrete_events = discrete_events(sys), + assertions = assertions(sys), name = nameof(sys), description = description(sys), metadata = get_metadata(sys)) @set newsys.parent = sys end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 9c8c272c5c..f8630f2d20 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -165,7 +165,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys), - parameter_dependencies = parameter_dependencies(sys), + parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys), guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end diff --git a/test/debugging.jl b/test/debugging.jl new file mode 100644 index 0000000000..aad36eb995 --- /dev/null +++ b/test/debugging.jl @@ -0,0 +1,50 @@ +using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, SymbolicIndexingInterface +using ModelingToolkit: t_nounits as t, D_nounits as D, ASSERTION_LOG_VARIABLE + +@variables x(t) +@brownian a +@named inner_ode = ODESystem(D(x) ~ -sqrt(x), t; assertions = [(x > 0) => "ohno"]) +@named inner_sde = System([D(x) ~ -sqrt(x) + a], t; assertions = [(x > 0) => "ohno"]) +sys_ode = structural_simplify(inner_ode) +sys_sde = structural_simplify(inner_sde) + +@testset "assertions are present in generated `f`" begin + @testset "$(typeof(sys))" for (Problem, sys, alg) in [ + (ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())] + @test !is_parameter(sys, ASSERTION_LOG_VARIABLE) + prob = Problem(sys, [x => 0.1], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + @test isnan(prob.f.f([0.0], prob.p, sol.t[end])[1]) + end +end + +@testset "`debug_system` adds logging" begin + @testset "$(typeof(sys))" for (Problem, sys, alg) in [ + (ODEProblem, sys_ode, Tsit5()), (SDEProblem, sys_sde, ImplicitEM())] + dsys = debug_system(sys; functions = []) + @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) + prob = Problem(dsys, [x => 0.1], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + prob.ps[ASSERTION_LOG_VARIABLE] = true + sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + end +end + +@testset "Hierarchical system" begin + @testset "$(typeof(inner))" for (ctor, Problem, inner, alg) in [ + (ODESystem, ODEProblem, inner_ode, Tsit5()), + (System, SDEProblem, inner_sde, ImplicitEM())] + @mtkbuild outer = ctor(Equation[], t; systems = [inner]) + dsys = debug_system(outer; functions = []) + @test is_parameter(dsys, ASSERTION_LOG_VARIABLE) + prob = Problem(dsys, [inner.x => 0.1], (0.0, 5.0)) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + prob.ps[ASSERTION_LOG_VARIABLE] = true + sol = @test_logs (:error, r"ohno") match_mode=:any solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 52875fdae5..9537b1b44e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,6 +92,7 @@ end @safetestset "IfLifting Test" include("if_lifting.jl") @safetestset "Analysis Points Test" include("analysis_points.jl") @safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl") + @safetestset "Debugging Test" include("debugging.jl") end end