Skip to content

Commit 90e487e

Browse files
committed
add discrete events to more systems
1 parent 05a215a commit 90e487e

File tree

5 files changed

+109
-19
lines changed

5 files changed

+109
-19
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
9191
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9292
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9393
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
94+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9495
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9596
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9697
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
@@ -99,4 +100,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
99100
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
100101

101102
[targets]
102-
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
103+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/systems/callbacks.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -440,10 +440,11 @@ function compile_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
440440
compile_user_affect(affect, sys, dvs, ps; kwargs...)
441441
end
442442

443-
function generate_timed_callback(cb, sys, dvs, ps; kwargs...)
443+
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
444+
kwargs...)
444445
cond = condition(cb)
445446
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
446-
kwargs...)
447+
postprocess_affect_expr!, kwargs...)
447448
if cond isa AbstractVector
448449
# Preset Time
449450
return PresetTimeCallback(cond, as)
@@ -453,13 +454,15 @@ function generate_timed_callback(cb, sys, dvs, ps; kwargs...)
453454
end
454455
end
455456

456-
function generate_discrete_callback(cb, sys, dvs, ps; kwargs...)
457+
function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
458+
kwargs...)
457459
if is_timed_condition(cb)
458-
return generate_timed_callback(cb, sys, dvs, ps, kwargs...)
460+
return generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr!,
461+
kwargs...)
459462
else
460463
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
461464
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
462-
kwargs...)
465+
postprocess_affect_expr!, kwargs...)
463466
return DiscreteCallback(c, as)
464467
end
465468
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,10 @@ symbolically calculating numerical enhancements.
527527
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
528528
parammap = DiffEqBase.NullParameters();
529529
sparsenoise = nothing, check_length = true,
530-
kwargs...) where {iip}
530+
callback = nothing, kwargs...) where {iip}
531531
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
532532
kwargs...)
533-
cbs = process_events(sys; kwargs...)
533+
cbs = process_events(sys; callback)
534534
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
535535

536536
noiseeqs = get_noiseeqs(sys)

src/systems/jumps/jumpsystem.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
22

3+
# modifies the expression representating an affect function to
4+
# call reset_aggregated_jumps!(integrator).
35
# assumes iip
4-
function add_jump_resetting!(expr, integrator)
6+
function _reset_aggregator!(expr, integrator)
57
if expr isa Symbol
68
error("Error, encountered a symbol. This should not happen.")
79
end
810

911
if (expr.head == :function)
10-
add_jump_resetting!(expr.args[end], integrator)
12+
_reset_aggregator!(expr.args[end], integrator)
1113
else
1214
if expr.args[end] == :nothing
1315
expr.args[end] = :(reset_aggregated_jumps!($integrator))
1416
push!(expr.args, :nothing)
1517
else
16-
add_jump_resetting!(expr.args[end], integrator)
18+
_reset_aggregator!(expr.args[end], integrator)
1719
end
1820
end
1921

@@ -73,16 +75,25 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
7375
type: type of the system
7476
"""
7577
connector_type::Any
78+
"""
79+
discrete_events: A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
80+
analog to `SciMLBase.DiscreteCallback` that exectues an affect when a given condition is
81+
true at the end of an integration step. *Note, one must make sure to call
82+
`reset_aggregated_jumps!(integrator)` if using a custom affect function that changes any
83+
state value or parameter.*
84+
"""
85+
discrete_events::Vector{SymbolicDiscreteCallback}
86+
7687
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems,
77-
defaults, connector_type;
88+
defaults, connector_type, devents;
7889
checks::Bool = true) where {U <: ArrayPartition}
7990
if checks
8091
check_variables(states, iv)
8192
check_parameters(ps, iv)
8293
all_dimensionless([states; ps; iv]) || check_units(ap, iv)
8394
end
8495
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
85-
connector_type)
96+
connector_type, devents)
8697
end
8798
end
8899

@@ -95,6 +106,8 @@ function JumpSystem(eqs, iv, states, ps;
95106
name = nothing,
96107
connector_type = nothing,
97108
checks = true,
109+
continuous_events = nothing,
110+
discrete_events = nothing,
98111
kwargs...)
99112
name === nothing &&
100113
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -127,9 +140,12 @@ function JumpSystem(eqs, iv, states, ps;
127140
process_variables!(var_to_name, defaults, states)
128141
process_variables!(var_to_name, defaults, ps)
129142
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
143+
(continuous_events === nothing) ||
144+
error("JumpSystems currently only support discrete events.")
145+
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
130146

131147
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems,
132-
defaults, connector_type, checks = checks)
148+
defaults, connector_type, disc_callbacks; checks = checks)
133149
end
134150

135151
function generate_rate_function(js::JumpSystem, rate)
@@ -325,7 +341,8 @@ jprob = JumpProblem(js, dprob, Direct())
325341
sol = solve(jprob, SSAStepper())
326342
```
327343
"""
328-
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
344+
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing,
345+
kwargs...)
329346
statetoid = Dict(value(state) => i for (i, state) in enumerate(states(js)))
330347
eqs = equations(js)
331348
invttype = prob.tspan[1] === nothing ? Float64 : typeof(1 / prob.tspan[2])
@@ -354,9 +371,12 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
354371
jtoj = nothing
355372
end
356373

374+
# handle events, making sure to reset aggregators in the generated affect functions
375+
cbs = process_events(js; callback, postprocess_affect_expr! = _reset_aggregator!)
376+
357377
JumpProblem(prob, aggregator, jset; dep_graph = jtoj, vartojumps_map = vtoj,
358-
jumptovars_map = jtov,
359-
scale_rates = false, nocopy = true, kwargs...)
378+
jumptovars_map = jtov, scale_rates = false, nocopy = true,
379+
callback = cbs, kwargs...)
360380
end
361381

362382
### Functions to determine which states a jump depends on

test/root_equations.jl

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, Test
1+
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq, JumpProcesses, Test
22
using ModelingToolkit: SymbolicContinuousCallback, SymbolicContinuousCallbacks, NULL_AFFECT,
33
get_callback
4+
using StableRNGs
5+
rng = StableRNG(12345)
46

57
@parameters t
68
@variables x(t) = 0
@@ -461,7 +463,6 @@ let
461463
end
462464
cb2‵‵ = [2.0] => (affect!, [], [k], nothing)
463465
@named ssys4 = SDESystem(eqs, Equation[], t, [A], [k, t1], discrete_events = [cb1, cb2‵‵])
464-
oprob4 = ODEProblem(ssys4, u0, tspan, p)
465466
testsol(ssys4, u0, p, tspan; tstops = [1.0])
466467

467468
# mixing with symbolic condition in the func affect
@@ -480,3 +481,68 @@ let
480481
sol = testsol(ssys7, u0, p, (0.0, 10.0); tstops = [1.0, 2.0], skipparamtest = true)
481482
@test isapprox(sol(10.0)[1], .1; atol=1e-10, rtol=1e-10)
482483
end
484+
485+
let rng = rng
486+
function testsol(jsys, u0, p, tspan; tstops = Float64[], skipparamtest = false,
487+
N = 40000, kwargs...)
488+
dprob = DiscreteProblem(jsys, u0, tspan, p)
489+
jprob = JumpProblem(jsys, dprob, Direct(); kwargs...)
490+
sol = solve(jprob, SSAStepper(); tstops = tstops)
491+
@test (sol(1.000000000001)[1] - sol(0.99999999999)[1]) == 1
492+
!skipparamtest && (@test dprob.p[1] == 1.0)
493+
@test sol(40.0)[1] == 0
494+
sol
495+
end
496+
497+
@parameters k t1 t2
498+
@variables t A(t) B(t)
499+
500+
cond1 = (t == t1)
501+
affect1 = [A ~ A + 1]
502+
cb1 = cond1 => affect1
503+
cond2 = (t == t2)
504+
affect2 = [k ~ 1.0]
505+
cb2 = cond2 => affect2
506+
507+
eqs = [MassActionJump(k, [A => 1], [A => -1])]
508+
@named jsys = JumpSystem(eqs, t, [A], [k, t1, t2], discrete_events = [cb1, cb2])
509+
u0 = [A => 1]
510+
p = [k => 0.0, t1 => 1.0, t2 => 2.0]
511+
tspan = (0.0, 40.0)
512+
testsol(jsys, u0, p, tspan; tstops = [1.0, 2.0], rng)
513+
514+
cond1a = (t == t1)
515+
affect1a = [A ~ A + 1, B ~ A]
516+
cb1a = cond1a => affect1a
517+
@named jsys1 = JumpSystem(eqs, t, [A, B], [k, t1, t2], discrete_events = [cb1a, cb2])
518+
u0′ = [A => 1, B => 0]
519+
sol = testsol(jsys1, u0′, p, tspan; tstops = [1.0, 2.0], check_length = false, rng)
520+
@test sol(1.000000001, idxs = B) == 2
521+
522+
# same as above - but with set-time event syntax
523+
cb1‵ = [1.0] => affect1 # needs to be a Vector for the event to happen only once
524+
cb2‵ = [2.0] => affect2
525+
@named jsys‵ = JumpSystem(eqs, t, [A], [k], discrete_events = [cb1‵, cb2‵])
526+
testsol(jsys‵, u0, [p[1]], tspan; rng)
527+
528+
# mixing discrete affects
529+
@named jsys3 = JumpSystem(eqs, t, [A], [k, t1, t2], discrete_events = [cb1, cb2‵])
530+
testsol(jsys3, u0, p, tspan; tstops = [1.0], rng)
531+
532+
# mixing with a func affect
533+
function affect!(integrator, u, p, ctx)
534+
integrator.p[p.k] = 1.0
535+
reset_aggregated_jumps!(integrator)
536+
nothing
537+
end
538+
cb2‵‵ = [2.0] => (affect!, [], [k], nothing)
539+
@named jsys4 = JumpSystem(eqs, t, [A], [k, t1], discrete_events = [cb1, cb2‵‵])
540+
testsol(jsys4, u0, p, tspan; tstops = [1.0], rng)
541+
542+
# mixing with symbolic condition in the func affect
543+
cb2‵‵‵ = (t == t2) => (affect!, [], [k], nothing)
544+
@named jsys5 = JumpSystem(eqs, t, [A], [k, t1, t2], discrete_events = [cb1, cb2‵‵‵])
545+
testsol(jsys5, u0, p, tspan; tstops = [1.0, 2.0], rng)
546+
@named jsys6 = JumpSystem(eqs, t, [A], [k, t1, t2], discrete_events = [cb2‵‵‵, cb1])
547+
testsol(jsys6, u0, p, tspan; tstops = [1.0, 2.0], rng)
548+
end

0 commit comments

Comments
 (0)