Skip to content

Commit 5f8cda3

Browse files
feat: add assertions field to ODESystem
1 parent b107633 commit 5f8cda3

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

src/systems/abstractsystem.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ for prop in [:eqs
983983
:gui_metadata
984984
:discrete_subsystems
985985
:parameter_dependencies
986+
:assertions
986987
:solved_unknowns
987988
:split_idxs
988989
:parent
@@ -1468,6 +1469,24 @@ end
14681469
"""
14691470
$(TYPEDSIGNATURES)
14701471
1472+
Get the assertions for a system `sys` and its subsystems.
1473+
"""
1474+
function assertions(sys::AbstractSystem)
1475+
has_assertions(sys) || return Dict{BasicSymbolic, String}()
1476+
1477+
asserts = get_assertions(sys)
1478+
systems = get_systems(sys)
1479+
namespaced_asserts = mapreduce(
1480+
merge!, systems; init = Dict{BasicSymbolic, String}()) do subsys
1481+
Dict{BasicSymbolic, String}(namespace_expr(k, subsys) => v
1482+
for (k, v) in assertions(subsys))
1483+
end
1484+
return merge(asserts, namespaced_asserts)
1485+
end
1486+
1487+
"""
1488+
$(TYPEDSIGNATURES)
1489+
14711490
Get the guesses for variables in the initialization system of the system `sys` and its subsystems.
14721491
14731492
See also [`initialization_equations`](@ref) and [`ModelingToolkit.get_guesses`](@ref).
@@ -3036,6 +3055,11 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30363055
kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses))
30373056
end
30383057

3058+
if has_assertions(basesys)
3059+
kwargs = merge(
3060+
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
3061+
end
3062+
30393063
return T(args...; kwargs...)
30403064
end
30413065

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ struct ODESystem <: AbstractODESystem
137137
"""
138138
parameter_dependencies::Vector{Equation}
139139
"""
140+
Mapping of conditions which should be true throughout the solve to corresponding error
141+
messages. These will be added to the equations when calling `debug_system`.
142+
"""
143+
assertions::Dict{BasicSymbolic, String}
144+
"""
140145
Metadata for the system, to be used by downstream packages.
141146
"""
142147
metadata::Any
@@ -190,7 +195,7 @@ struct ODESystem <: AbstractODESystem
190195
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
191196
torn_matching, initializesystem, initialization_eqs, schedule,
192197
connector_type, preface, cevents,
193-
devents, parameter_dependencies,
198+
devents, parameter_dependencies, assertions = Dict{BasicSymbolic, String}(),
194199
metadata = nothing, gui_metadata = nothing, is_dde = false,
195200
tstops = [], tearing_state = nothing,
196201
substitutions = nothing, complete = false, index_cache = nothing,
@@ -210,7 +215,7 @@ struct ODESystem <: AbstractODESystem
210215
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
211216
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
212217
initializesystem, initialization_eqs, schedule, connector_type, preface,
213-
cevents, devents, parameter_dependencies, metadata,
218+
cevents, devents, parameter_dependencies, assertions, metadata,
214219
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
215220
discrete_subsystems, solved_unknowns, split_idxs, parent)
216221
end
@@ -235,6 +240,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
235240
continuous_events = nothing,
236241
discrete_events = nothing,
237242
parameter_dependencies = Equation[],
243+
assertions = Dict(),
238244
checks = true,
239245
metadata = nothing,
240246
gui_metadata = nothing,
@@ -286,12 +292,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
286292
if is_dde === nothing
287293
is_dde = _check_if_dde(deqs, iv′, systems)
288294
end
295+
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
289296
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
290297
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
291298
ctrl_jac, Wfact, Wfact_t, name, description, systems,
292299
defaults, guesses, nothing, initializesystem,
293300
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
294-
disc_callbacks, parameter_dependencies,
301+
disc_callbacks, parameter_dependencies, assertions,
295302
metadata, gui_metadata, is_dde, tstops, checks = checks)
296303
end
297304

@@ -364,6 +371,7 @@ function flatten(sys::ODESystem, noeqs = false)
364371
name = nameof(sys),
365372
description = description(sys),
366373
initialization_eqs = initialization_equations(sys),
374+
assertions = assertions(sys),
367375
is_dde = is_dde(sys),
368376
tstops = symbolic_tstops(sys),
369377
metadata = get_metadata(sys),

0 commit comments

Comments
 (0)