Skip to content

Commit 05a215a

Browse files
committed
add events for SDESystems
1 parent 74d3f9f commit 05a215a

File tree

4 files changed

+136
-16
lines changed

4 files changed

+136
-16
lines changed

src/systems/callbacks.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,17 +241,14 @@ end
241241
################################# compilation functions ####################################
242242

243243
# handles ensuring that affect! functions work with integrator arguments
244-
function add_integrator_header(out = :u)
245-
integrator = gensym(:MTKIntegrator)
246-
244+
function add_integrator_header(integrator = gensym(:MTKIntegrator), out = :u)
247245
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
248246
expr.body),
249247
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
250248
expr.body)
251249
end
252250

253-
function condition_header()
254-
integrator = gensym(:MTKIntegrator)
251+
function condition_header(integrator = gensym(:MTKIntegrator))
255252
expr -> Func([expr.args[1], expr.args[2],
256253
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])], [], expr.body)
257254
end
@@ -296,7 +293,8 @@ Notes
296293
- `kwargs` are passed through to `Symbolics.build_function`.
297294
"""
298295
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
299-
expression = Val{true}, checkvars = true, kwargs...)
296+
expression = Val{true}, checkvars = true,
297+
postprocess_affect_expr! = nothing, kwargs...)
300298
if isempty(eqs)
301299
if expression == Val{true}
302300
return :((args...) -> ())
@@ -337,10 +335,17 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
337335
p = ps
338336
end
339337
t = get_iv(sys)
340-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression,
341-
wrap_code = add_integrator_header(outvar),
338+
integ = gensym(:MTKIntegrator)
339+
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
340+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
341+
wrap_code = add_integrator_header(integ, outvar),
342342
outputidxs = update_inds,
343343
kwargs...)
344+
# applied user-provided function to the generated expression
345+
if postprocess_affect_expr! !== nothing
346+
postprocess_affect_expr!(rf_ip, integ)
347+
(expression == Val{false}) && (return @RuntimeGeneratedFunction(rf_ip))
348+
end
344349
rf_ip
345350
end
346351
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,30 @@ struct SDESystem <: AbstractODESystem
8585
type: type of the system
8686
"""
8787
connector_type::Any
88+
"""
89+
continuous_events: A `Vector{SymbolicContinuousCallback}` that model events.
90+
The integrator will use root finding to guarantee that it steps at each zero crossing.
91+
"""
92+
continuous_events::Vector{SymbolicContinuousCallback}
93+
"""
94+
discrete_events: A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
95+
analog to `SciMLBase.DiscreteCallback` that exectues an affect when a given condition is
96+
true at the end of an integration step.
97+
"""
98+
discrete_events::Vector{SymbolicDiscreteCallback}
8899

89100
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
90-
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type;
91-
checks::Bool = true)
101+
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
102+
cevents, devents; checks::Bool = true)
92103
if checks
93104
check_variables(dvs, iv)
94105
check_parameters(ps, iv)
95106
check_equations(deqs, iv)
107+
check_equations(equations(cevents), iv)
96108
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
97109
end
98110
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac,
99-
Wfact, Wfact_t, name, systems, defaults, connector_type)
111+
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents)
100112
end
101113
end
102114

@@ -109,7 +121,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
109121
defaults = _merge(Dict(default_u0), Dict(default_p)),
110122
name = nothing,
111123
connector_type = nothing,
112-
checks = true)
124+
checks = true,
125+
continuous_events = nothing,
126+
discrete_events = nothing)
113127
name === nothing &&
114128
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
115129
deqs = scalarize(deqs)
@@ -139,9 +153,16 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
139153
ctrl_jac = RefValue{Any}(EMPTY_JAC)
140154
Wfact = RefValue(EMPTY_JAC)
141155
Wfact_t = RefValue(EMPTY_JAC)
156+
sysnames = nameof.(systems)
157+
if length(unique(sysnames)) != length(sysnames)
158+
throw(ArgumentError("System names must be unique."))
159+
end
160+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
161+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
162+
142163
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
143164
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
144-
checks = checks)
165+
cont_callbacks, disc_callbacks; checks = checks)
145166
end
146167

147168
function SDESystem(sys::ODESystem, neqs; kwargs...)
@@ -509,6 +530,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
509530
kwargs...) where {iip}
510531
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
511532
kwargs...)
533+
cbs = process_events(sys; kwargs...)
512534
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
513535

514536
noiseeqs = get_noiseeqs(sys)
@@ -521,8 +543,8 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
521543
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
522544
end
523545

524-
SDEProblem{iip}(f, f.g, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
525-
kwargs...)
546+
SDEProblem{iip}(f, f.g, u0, tspan, p; callback = cbs,
547+
noise_rate_prototype = noise_rate_prototype, kwargs...)
526548
end
527549

528550
function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
22

3+
# assumes iip
4+
function add_jump_resetting!(expr, integrator)
5+
if expr isa Symbol
6+
error("Error, encountered a symbol. This should not happen.")
7+
end
8+
9+
if (expr.head == :function)
10+
add_jump_resetting!(expr.args[end], integrator)
11+
else
12+
if expr.args[end] == :nothing
13+
expr.args[end] = :(reset_aggregated_jumps!($integrator))
14+
push!(expr.args, :nothing)
15+
else
16+
add_jump_resetting!(expr.args[end], integrator)
17+
end
18+
end
19+
20+
nothing
21+
end
22+
323
"""
424
$(TYPEDEF)
525

test/root_equations.jl

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, OrdinaryDiffEq, Test
1+
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, Test
22
using ModelingToolkit: SymbolicContinuousCallback, SymbolicContinuousCallbacks, NULL_AFFECT,
33
get_callback
44

@@ -407,3 +407,76 @@ let
407407
sol = testsol(osys7, u0, p, (0.0, 10.0); tstops = [1.0, 2.0], skipparamtest = true)
408408
@test isapprox(sol(10.0)[1], .1; atol=1e-10, rtol=1e-10)
409409
end
410+
411+
let
412+
function testsol(ssys, u0, p, tspan; tstops = Float64[], skipparamtest = false, kwargs...)
413+
sprob = SDEProblem(ssys, u0, tspan, p; kwargs...)
414+
sol = solve(sprob, RI5(); tstops = tstops, abstol = 1e-10, reltol = 1e-10)
415+
@test isapprox(sol(1.0000000001)[1] - sol(0.999999999)[1], 1.0; rtol = 1e-4)
416+
!skipparamtest && (@test sprob.p[1] == 1.0)
417+
@test isapprox(sol(4.0)[1], 2 * exp(-2.0), atol=1e-4)
418+
sol
419+
end
420+
421+
@parameters k t1 t2
422+
@variables t A(t) B(t)
423+
424+
cond1 = (t == t1)
425+
affect1 = [A ~ A + 1]
426+
cb1 = cond1 => affect1
427+
cond2 = (t == t2)
428+
affect2 = [k ~ 1.0]
429+
cb2 = cond2 => affect2
430+
431+
∂ₜ = Differential(t)
432+
eqs = [∂ₜ(A) ~ -k * A]
433+
@named ssys = SDESystem(eqs, Equation[], t, [A], [k, t1, t2], discrete_events = [cb1, cb2])
434+
u0 = [A => 1.0]
435+
p = [k => 0.0, t1 => 1.0, t2 => 2.0]
436+
tspan = (0.0, 4.0)
437+
testsol(ssys, u0, p, tspan; tstops = [1.0, 2.0])
438+
439+
cond1a = (t == t1)
440+
affect1a = [A ~ A + 1, B ~ A]
441+
cb1a = cond1a => affect1a
442+
@named ssys1 = SDESystem(eqs, Equation[], t, [A, B], [k, t1, t2], discrete_events = [cb1a, cb2])
443+
u0′ = [A => 1.0, B => 0.0]
444+
sol = testsol(ssys1, u0′, p, tspan; tstops = [1.0, 2.0], check_length = false)
445+
@test sol(1.0000001, idxs = 2) == 2.0
446+
447+
# same as above - but with set-time event syntax
448+
cb1‵ = [1.0] => affect1 # needs to be a Vector for the event to happen only once
449+
cb2‵ = [2.0] => affect2
450+
@named ssys‵ = SDESystem(eqs, Equation[], t, [A], [k], discrete_events = [cb1‵, cb2‵])
451+
testsol(ssys‵, u0, p, tspan)
452+
453+
# mixing discrete affects
454+
@named ssys3 = SDESystem(eqs, Equation[], t, [A], [k, t1, t2], discrete_events = [cb1, cb2‵])
455+
testsol(ssys3, u0, p, tspan; tstops = [1.0])
456+
457+
# mixing with a func affect
458+
function affect!(integrator, u, p, ctx)
459+
integrator.p[p.k] = 1.0
460+
nothing
461+
end
462+
cb2‵‵ = [2.0] => (affect!, [], [k], nothing)
463+
@named ssys4 = SDESystem(eqs, Equation[], t, [A], [k, t1], discrete_events = [cb1, cb2‵‵])
464+
oprob4 = ODEProblem(ssys4, u0, tspan, p)
465+
testsol(ssys4, u0, p, tspan; tstops = [1.0])
466+
467+
# mixing with symbolic condition in the func affect
468+
cb2‵‵‵ = (t == t2) => (affect!, [], [k], nothing)
469+
@named ssys5 = SDESystem(eqs, Equation[], t, [A], [k, t1, t2], discrete_events = [cb1, cb2‵‵‵])
470+
testsol(ssys5, u0, p, tspan; tstops = [1.0, 2.0])
471+
@named ssys6 = SDESystem(eqs, Equation[], t, [A], [k, t1, t2], discrete_events = [cb2‵‵‵, cb1])
472+
testsol(ssys6, u0, p, tspan; tstops = [1.0, 2.0])
473+
474+
# mix a continuous event too
475+
cond3 = A ~ .1
476+
affect3 = [k ~ 0.0]
477+
cb3 = cond3 => affect3
478+
@named ssys7 = SDESystem(eqs, Equation[], t, [A], [k, t1, t2], discrete_events = [cb1, cb2‵‵‵],
479+
continuous_events = [cb3])
480+
sol = testsol(ssys7, u0, p, (0.0, 10.0); tstops = [1.0, 2.0], skipparamtest = true)
481+
@test isapprox(sol(10.0)[1], .1; atol=1e-10, rtol=1e-10)
482+
end

0 commit comments

Comments
 (0)