Skip to content

Commit 7e7eacf

Browse files
committed
add discrete_events field for ODEs
1 parent 707a803 commit 7e7eacf

File tree

6 files changed

+37
-16
lines changed

6 files changed

+37
-16
lines changed

src/structural_transformation/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra
22

3-
using ModelingToolkit: isdifferenceeq, has_continuous_events, process_events
3+
using ModelingToolkit: isdifferenceeq, process_events
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

src/systems/abstractsystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,7 @@ function Base.hash(sys::AbstractSystem, s::UInt)
10391039
end
10401040
s = foldr(hash, get_observed(sys), init = s)
10411041
s = foldr(hash, get_continuous_events(sys), init = s)
1042+
s = foldr(hash, get_discrete_events(sys), init = s)
10421043
s = hash(independent_variables(sys), s)
10431044
return s
10441045
end
@@ -1066,16 +1067,17 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
10661067
sts = union(get_states(basesys), get_states(sys))
10671068
ps = union(get_ps(basesys), get_ps(sys))
10681069
obs = union(get_observed(basesys), get_observed(sys))
1069-
evs = union(get_continuous_events(basesys), get_continuous_events(sys))
1070+
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
1071+
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
10701072
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
10711073
syss = union(get_systems(basesys), get_systems(sys))
10721074

10731075
if length(ivs) == 0
10741076
T(eqs, sts, ps, observed = obs, defaults = defs, name = name, systems = syss,
1075-
continuous_events = evs)
1077+
continuous_events = cevs, discrete_events = devs)
10761078
elseif length(ivs) == 1
10771079
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name,
1078-
systems = syss, continuous_events = evs)
1080+
systems = syss, continuous_events = cevs, discrete_events = devs)
10791081
end
10801082
end
10811083

src/systems/callbacks.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
55

66
has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events)
77
function get_discrete_events(sys::AbstractSystem)
8-
has_discrete_events(sys) ||
9-
error("Systems of type $(typeof(sys)) do not support discrete events.")
8+
has_discrete_events(sys) || return SymbolicDiscreteCallback[]
109
getfield(sys, :discrete_events)
1110
end
1211

@@ -62,10 +61,10 @@ affect_equations(cb::SymbolicContinuousCallback) = cb.affect
6261
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
6362
reduce(vcat, [affect_equations(cb) for cb in cbs])
6463
end
65-
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb),
66-
(s,)),
67-
namespace_equation.(affect_equations(cb),
68-
(s,)))
64+
function namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
65+
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
66+
namespace_equation.(affect_equations(cb), (s,)))
67+
end
6968

7069
function continuous_events(sys::AbstractSystem)
7170
obs = get_continuous_events(sys)
@@ -124,6 +123,11 @@ function namespace_equation(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCa
124123
namespace_equation.(affect_equations(cb), Ref(s)))
125124
end
126125

126+
SymbolicDiscreteCallbacks(cb::SymbolicDiscreteCallback) = [cb]
127+
SymbolicDiscreteCallbacks(cbs::Vector{<:SymbolicDiscreteCallback}) = cbs
128+
SymbolicDiscreteCallbacks(cbs::Vector) = SymbolicDiscreteCallback.(cbs)
129+
SymbolicDiscreteCallbacks(::Nothing) = SymbolicDiscreteCallback[]
130+
127131
function discrete_events(sys::AbstractSystem)
128132
obs = get_discrete_events(sys)
129133
systems = get_systems(sys)

src/systems/diffeqs/odesystem.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ struct ODESystem <: AbstractODESystem
102102
"""
103103
continuous_events::Vector{SymbolicContinuousCallback}
104104
"""
105+
discrete_events: A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
106+
analog to `SciMLBase.DiscreteCallback` that exectues an affect when a given condition is
107+
true at the end of an integration step.
108+
"""
109+
discrete_events::Vector{SymbolicDiscreteCallback}
110+
"""
105111
tearing_state: cache for intermediate tearing state
106112
"""
107113
tearing_state::Any
@@ -112,19 +118,20 @@ struct ODESystem <: AbstractODESystem
112118

113119
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
114120
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
115-
torn_matching, connector_type, connections, preface, events,
116-
tearing_state = nothing, substitutions = nothing;
121+
torn_matching, connector_type, connections, preface, cevents,
122+
devents, tearing_state = nothing, substitutions = nothing;
117123
checks::Bool = true)
118124
if checks
119125
check_variables(dvs, iv)
120126
check_parameters(ps, iv)
121127
check_equations(deqs, iv)
122-
check_equations(equations(events), iv)
128+
check_equations(equations(cevents), iv)
123129
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
124130
end
125131
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
126132
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
127-
connector_type, connections, preface, events, tearing_state, substitutions)
133+
connector_type, connections, preface, cevents, devents, tearing_state,
134+
substitutions)
128135
end
129136
end
130137

@@ -139,6 +146,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
139146
connector_type = nothing,
140147
preface = nothing,
141148
continuous_events = nothing,
149+
discrete_events = nothing,
142150
checks = true)
143151
name === nothing &&
144152
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -172,9 +180,11 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
172180
throw(ArgumentError("System names must be unique."))
173181
end
174182
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
183+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
175184
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
176185
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
177-
connector_type, nothing, preface, cont_callbacks, checks = checks)
186+
connector_type, nothing, preface, cont_callbacks, disc_callbacks,
187+
checks = checks)
178188
end
179189

180190
function ODESystem(eqs, iv = nothing; kwargs...)
@@ -244,6 +254,7 @@ function flatten(sys::ODESystem, noeqs = false)
244254
parameters(sys),
245255
observed = observed(sys),
246256
continuous_events = continuous_events(sys),
257+
discrete_events = discrete_events(sys),
247258
defaults = defaults(sys),
248259
name = nameof(sys),
249260
checks = false)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,13 @@ function NonlinearSystem(eqs, states, ps;
8383
systems = NonlinearSystem[],
8484
connector_type = nothing,
8585
continuous_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
86+
discrete_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
8687
checks = true)
8788
continuous_events === nothing || isempty(continuous_events) ||
8889
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
90+
discrete_events === nothing || isempty(discrete_events) ||
91+
throw(ArgumentError("NonlinearSystem does not accept `discrete_events`, you provided $discrete_events"))
92+
8993
name === nothing &&
9094
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
9195
# Move things over, but do not touch array expressions

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ println("Last test requires gcc available in the path!")
3939
@testset "Serialization" begin include("serialization.jl") end
4040
@safetestset "print_tree" begin include("print_tree.jl") end
4141
@safetestset "error_handling" begin include("error_handling.jl") end
42-
@safetestset "Callbacks" begin include("root_equations.jl") end
42+
@safetestset "root_equations" begin include("root_equations.jl") end
4343
@safetestset "state_selection" begin include("state_selection.jl") end
4444
@safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end
4545
@safetestset "ControlSystem Test" begin include("controlsystem.jl") end

0 commit comments

Comments
 (0)