Skip to content

Commit 46778e3

Browse files
authored
Merge pull request #1714 from isaacsas/disc-cb-three
Add events to SDESystems and JumpSystems
2 parents 01186a4 + c037622 commit 46778e3

File tree

5 files changed

+294
-42
lines changed

5 files changed

+294
-42
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ ForwardDiff = "0.10.3"
5959
Graphs = "1.5.2"
6060
IfElse = "0.1"
6161
JuliaFormatter = "1"
62-
JumpProcesses = "9"
62+
JumpProcesses = "9.1"
6363
LabelledArrays = "1.3"
6464
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15"
6565
MacroTools = "0.5"
@@ -91,6 +91,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
9191
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9292
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9393
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
94+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9495
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9596
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9697
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
@@ -99,4 +100,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
99100
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
100101

101102
[targets]
102-
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
103+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/systems/callbacks.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,17 +241,14 @@ end
241241
################################# compilation functions ####################################
242242

243243
# handles ensuring that affect! functions work with integrator arguments
244-
function add_integrator_header(out = :u)
245-
integrator = gensym(:MTKIntegrator)
246-
244+
function add_integrator_header(integrator = gensym(:MTKIntegrator), out = :u)
247245
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
248246
expr.body),
249247
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
250248
expr.body)
251249
end
252250

253-
function condition_header()
254-
integrator = gensym(:MTKIntegrator)
251+
function condition_header(integrator = gensym(:MTKIntegrator))
255252
expr -> Func([expr.args[1], expr.args[2],
256253
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])], [], expr.body)
257254
end
@@ -296,7 +293,8 @@ Notes
296293
- `kwargs` are passed through to `Symbolics.build_function`.
297294
"""
298295
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
299-
expression = Val{true}, checkvars = true, kwargs...)
296+
expression = Val{true}, checkvars = true,
297+
postprocess_affect_expr! = nothing, kwargs...)
300298
if isempty(eqs)
301299
if expression == Val{true}
302300
return :((args...) -> ())
@@ -337,10 +335,17 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
337335
p = ps
338336
end
339337
t = get_iv(sys)
340-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression,
341-
wrap_code = add_integrator_header(outvar),
338+
integ = gensym(:MTKIntegrator)
339+
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
340+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
341+
wrap_code = add_integrator_header(integ, outvar),
342342
outputidxs = update_inds,
343343
kwargs...)
344+
# applied user-provided function to the generated expression
345+
if postprocess_affect_expr! !== nothing
346+
postprocess_affect_expr!(rf_ip, integ)
347+
(expression == Val{false}) && (return @RuntimeGeneratedFunction(rf_ip))
348+
end
344349
rf_ip
345350
end
346351
end
@@ -435,10 +440,11 @@ function compile_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
435440
compile_user_affect(affect, sys, dvs, ps; kwargs...)
436441
end
437442

438-
function generate_timed_callback(cb, sys, dvs, ps; kwargs...)
443+
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
444+
kwargs...)
439445
cond = condition(cb)
440446
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
441-
kwargs...)
447+
postprocess_affect_expr!, kwargs...)
442448
if cond isa AbstractVector
443449
# Preset Time
444450
return PresetTimeCallback(cond, as)
@@ -448,13 +454,15 @@ function generate_timed_callback(cb, sys, dvs, ps; kwargs...)
448454
end
449455
end
450456

451-
function generate_discrete_callback(cb, sys, dvs, ps; kwargs...)
457+
function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
458+
kwargs...)
452459
if is_timed_condition(cb)
453-
return generate_timed_callback(cb, sys, dvs, ps, kwargs...)
460+
return generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr!,
461+
kwargs...)
454462
else
455463
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
456464
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
457-
kwargs...)
465+
postprocess_affect_expr!, kwargs...)
458466
return DiscreteCallback(c, as)
459467
end
460468
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,30 @@ struct SDESystem <: AbstractODESystem
8585
type: type of the system
8686
"""
8787
connector_type::Any
88+
"""
89+
continuous_events: A `Vector{SymbolicContinuousCallback}` that model events.
90+
The integrator will use root finding to guarantee that it steps at each zero crossing.
91+
"""
92+
continuous_events::Vector{SymbolicContinuousCallback}
93+
"""
94+
discrete_events: A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
95+
analog to `SciMLBase.DiscreteCallback` that exectues an affect when a given condition is
96+
true at the end of an integration step.
97+
"""
98+
discrete_events::Vector{SymbolicDiscreteCallback}
8899

89100
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
90-
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type;
91-
checks::Bool = true)
101+
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
102+
cevents, devents; checks::Bool = true)
92103
if checks
93104
check_variables(dvs, iv)
94105
check_parameters(ps, iv)
95106
check_equations(deqs, iv)
107+
check_equations(equations(cevents), iv)
96108
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
97109
end
98110
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac,
99-
Wfact, Wfact_t, name, systems, defaults, connector_type)
111+
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents)
100112
end
101113
end
102114

@@ -109,7 +121,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
109121
defaults = _merge(Dict(default_u0), Dict(default_p)),
110122
name = nothing,
111123
connector_type = nothing,
112-
checks = true)
124+
checks = true,
125+
continuous_events = nothing,
126+
discrete_events = nothing)
113127
name === nothing &&
114128
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
115129
deqs = scalarize(deqs)
@@ -139,9 +153,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
139153
ctrl_jac = RefValue{Any}(EMPTY_JAC)
140154
Wfact = RefValue(EMPTY_JAC)
141155
Wfact_t = RefValue(EMPTY_JAC)
156+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
157+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
158+
142159
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
143160
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
144-
checks = checks)
161+
cont_callbacks, disc_callbacks; checks = checks)
145162
end
146163

147164
function SDESystem(sys::ODESystem, neqs; kwargs...)
@@ -491,9 +508,10 @@ end
491508
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
492509
parammap = DiffEqBase.NullParameters();
493510
sparsenoise = nothing, check_length = true,
494-
kwargs...) where {iip}
511+
callback = nothing, kwargs...) where {iip}
495512
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
496513
kwargs...)
514+
cbs = process_events(sys; callback)
497515
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
498516

499517
noiseeqs = get_noiseeqs(sys)
@@ -506,8 +524,8 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
506524
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
507525
end
508526

509-
SDEProblem{iip}(f, f.g, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
510-
kwargs...)
527+
SDEProblem{iip}(f, f.g, u0, tspan, p; callback = cbs,
528+
noise_rate_prototype = noise_rate_prototype, kwargs...)
511529
end
512530

513531
"""

src/systems/jumps/jumpsystem.jl

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,27 @@
11
const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
22

3+
# modifies the expression representating an affect function to
4+
# call reset_aggregated_jumps!(integrator).
5+
# assumes iip
6+
function _reset_aggregator!(expr, integrator)
7+
if expr isa Symbol
8+
error("Error, encountered a symbol. This should not happen.")
9+
end
10+
11+
if (expr.head == :function)
12+
_reset_aggregator!(expr.args[end], integrator)
13+
else
14+
if expr.args[end] == :nothing
15+
expr.args[end] = :(reset_aggregated_jumps!($integrator))
16+
push!(expr.args, :nothing)
17+
else
18+
_reset_aggregator!(expr.args[end], integrator)
19+
end
20+
end
21+
22+
nothing
23+
end
24+
325
"""
426
$(TYPEDEF)
527
@@ -53,16 +75,25 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
5375
type: type of the system
5476
"""
5577
connector_type::Any
78+
"""
79+
discrete_events: A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
80+
analog to `SciMLBase.DiscreteCallback` that exectues an affect when a given condition is
81+
true at the end of an integration step. *Note, one must make sure to call
82+
`reset_aggregated_jumps!(integrator)` if using a custom affect function that changes any
83+
state value or parameter.*
84+
"""
85+
discrete_events::Vector{SymbolicDiscreteCallback}
86+
5687
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems,
57-
defaults, connector_type;
88+
defaults, connector_type, devents;
5889
checks::Bool = true) where {U <: ArrayPartition}
5990
if checks
6091
check_variables(states, iv)
6192
check_parameters(ps, iv)
6293
all_dimensionless([states; ps; iv]) || check_units(ap, iv)
6394
end
6495
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
65-
connector_type)
96+
connector_type, devents)
6697
end
6798
end
6899

@@ -75,6 +106,8 @@ function JumpSystem(eqs, iv, states, ps;
75106
name = nothing,
76107
connector_type = nothing,
77108
checks = true,
109+
continuous_events = nothing,
110+
discrete_events = nothing,
78111
kwargs...)
79112
name === nothing &&
80113
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -107,9 +140,12 @@ function JumpSystem(eqs, iv, states, ps;
107140
process_variables!(var_to_name, defaults, states)
108141
process_variables!(var_to_name, defaults, ps)
109142
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
143+
(continuous_events === nothing) ||
144+
error("JumpSystems currently only support discrete events.")
145+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
110146

111147
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems,
112-
defaults, connector_type, checks = checks)
148+
defaults, connector_type, disc_callbacks; checks = checks)
113149
end
114150

115151
function generate_rate_function(js::JumpSystem, rate)
@@ -305,7 +341,8 @@ jprob = JumpProblem(js, dprob, Direct())
305341
sol = solve(jprob, SSAStepper())
306342
```
307343
"""
308-
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
344+
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing,
345+
kwargs...)
309346
statetoid = Dict(value(state) => i for (i, state) in enumerate(states(js)))
310347
eqs = equations(js)
311348
invttype = prob.tspan[1] === nothing ? Float64 : typeof(1 / prob.tspan[2])
@@ -334,9 +371,12 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
334371
jtoj = nothing
335372
end
336373

374+
# handle events, making sure to reset aggregators in the generated affect functions
375+
cbs = process_events(js; callback, postprocess_affect_expr! = _reset_aggregator!)
376+
337377
JumpProblem(prob, aggregator, jset; dep_graph = jtoj, vartojumps_map = vtoj,
338-
jumptovars_map = jtov,
339-
scale_rates = false, nocopy = true, kwargs...)
378+
jumptovars_map = jtov, scale_rates = false, nocopy = true,
379+
callback = cbs, kwargs...)
340380
end
341381

342382
### Functions to determine which states a jump depends on

0 commit comments

Comments
 (0)