Skip to content

Commit b5e6dd9

Browse files
feat: add assertions field to SDESystem
1 parent aa6f340 commit b5e6dd9

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ struct SDESystem <: AbstractODESystem
126126
"""
127127
parameter_dependencies::Vector{Equation}
128128
"""
129+
Mapping of conditions which should be true throughout the solve to corresponding error
130+
messages. These will be added to the equations when calling `debug_system`.
131+
"""
132+
assertions::Dict{BasicSymbolic, String}
133+
"""
129134
Metadata for the system, to be used by downstream packages.
130135
"""
131136
metadata::Any
@@ -159,7 +164,9 @@ struct SDESystem <: AbstractODESystem
159164
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
160165
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161166
guesses, initializesystem, initialization_eqs, connector_type,
162-
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
167+
cevents, devents, parameter_dependencies, assertions = Dict{
168+
BasicSymbolic, Nothing},
169+
metadata = nothing, gui_metadata = nothing,
163170
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
164171
is_dde = false,
165172
isscheduled = false;
@@ -185,9 +192,8 @@ struct SDESystem <: AbstractODESystem
185192
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
186193
ctrl_jac, Wfact, Wfact_t, name, description, systems,
187194
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
188-
devents,
189-
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
190-
is_dde, isscheduled)
195+
devents, parameter_dependencies, assertions, metadata, gui_metadata, complete,
196+
index_cache, parent, is_scalar_noise, is_dde, isscheduled)
191197
end
192198
end
193199

@@ -209,6 +215,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
209215
continuous_events = nothing,
210216
discrete_events = nothing,
211217
parameter_dependencies = Equation[],
218+
assertions = Dict{BasicSymbolic, String}(),
212219
metadata = nothing,
213220
gui_metadata = nothing,
214221
complete = false,
@@ -261,11 +268,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
261268
if is_dde === nothing
262269
is_dde = _check_if_dde(deqs, iv′, systems)
263270
end
271+
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
264272
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
265273
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
266274
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267275
initializesystem, initialization_eqs, connector_type,
268-
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
276+
cont_callbacks, disc_callbacks, parameter_dependencies, assertions, metadata, gui_metadata,
269277
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
270278
end
271279

@@ -378,6 +386,7 @@ function ODESystem(sys::SDESystem)
378386
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
379387
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
380388
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
389+
assertions = assertions(sys),
381390
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
382391
@set newsys.parent = sys
383392
end

src/systems/systems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
165165
return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs,
166166
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
167167
name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys),
168-
parameter_dependencies = parameter_dependencies(sys),
168+
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
169169
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
170170
end
171171
end

0 commit comments

Comments
 (0)