Skip to content

Commit 9cf859f

Browse files
Merge pull request #3181 from isaacsas/add_odes_to_jumpsys
Support ODEs as equations in JumpSystem
2 parents bc63b47 + 0ec1fa3 commit 9cf859f

File tree

3 files changed

+170
-32
lines changed

3 files changed

+170
-32
lines changed

src/systems/callbacks.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#################################### system operations #####################################
2-
get_continuous_events(sys::AbstractSystem) = SymbolicContinuousCallback[]
3-
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
42
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
3+
function get_continuous_events(sys::AbstractSystem)
4+
has_continuous_events(sys) || return SymbolicContinuousCallback[]
5+
getfield(sys, :continuous_events)
6+
end
57

68
has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events)
79
function get_discrete_events(sys::AbstractSystem)
@@ -676,8 +678,8 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
676678
end
677679
end
678680

679-
function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
680-
ps = parameters(sys); kwargs...)
681+
function generate_rootfinding_callback(sys::AbstractTimeDependentSystem,
682+
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
681683
cbs = continuous_events(sys)
682684
isempty(cbs) && return nothing
683685
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
@@ -687,7 +689,7 @@ Generate a single rootfinding callback; this happens if there is only one equati
687689
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
688690
"""
689691
function generate_single_rootfinding_callback(
690-
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
692+
eq, cb, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
691693
ps = parameters(sys); kwargs...)
692694
if !isequal(eq.lhs, 0)
693695
eq = 0 ~ eq.lhs - eq.rhs
@@ -729,7 +731,7 @@ function generate_single_rootfinding_callback(
729731
end
730732

731733
function generate_vector_rootfinding_callback(
732-
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
734+
cbs, sys::AbstractTimeDependentSystem, dvs = unknowns(sys),
733735
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
734736
reinitialization = SciMLBase.CheckInit(), kwargs...)
735737
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
@@ -841,7 +843,7 @@ end
841843
"""
842844
Compile a single continuous callback affect function(s).
843845
"""
844-
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
846+
function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs)
845847
eq_aff = affects(cb)
846848
eq_neg_aff = affect_negs(cb)
847849
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
@@ -858,8 +860,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
858860
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
859861
end
860862

861-
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
862-
ps = parameters(sys); kwargs...)
863+
function generate_rootfinding_callback(cbs, sys::AbstractTimeDependentSystem,
864+
dvs = unknowns(sys), ps = parameters(sys); kwargs...)
863865
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
864866
num_eqs = length.(eqs)
865867
total_eqs = sum(num_eqs)
@@ -1053,12 +1055,12 @@ merge_cb(x, ::Nothing) = x
10531055
merge_cb(x, y) = CallbackSet(x, y)
10541056

10551057
function process_events(sys; callback = nothing, kwargs...)
1056-
if has_continuous_events(sys)
1058+
if has_continuous_events(sys) && !isempty(continuous_events(sys))
10571059
contin_cb = generate_rootfinding_callback(sys; kwargs...)
10581060
else
10591061
contin_cb = nothing
10601062
end
1061-
if has_discrete_events(sys)
1063+
if has_discrete_events(sys) && !isempty(discrete_events(sys))
10621064
discrete_cb = generate_discrete_callbacks(sys; kwargs...)
10631065
else
10641066
discrete_cb = nothing

src/systems/jumps/jumpsystem.jl

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
8888
"""
8989
connector_type::Any
9090
"""
91+
A `Vector{SymbolicContinuousCallback}` that model events.
92+
The integrator will use root finding to guarantee that it steps at each zero crossing.
93+
"""
94+
continuous_events::Vector{SymbolicContinuousCallback}
95+
"""
9196
A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic
9297
analog to `SciMLBase.DiscreteCallback` that executes an affect when a given condition is
9398
true at the end of an integration step. Note, one must make sure to call
@@ -120,8 +125,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
120125

121126
function JumpSystem{U}(
122127
tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, description,
123-
systems,
124-
defaults, connector_type, devents, parameter_dependencies,
128+
systems, defaults, connector_type, cevents, devents, parameter_dependencies,
125129
metadata = nothing, gui_metadata = nothing,
126130
complete = false, index_cache = nothing, isscheduled = false;
127131
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
@@ -136,8 +140,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
136140
end
137141
new{U}(tag, ap, iv, unknowns, ps, var_to_name,
138142
observed, name, description, systems, defaults,
139-
connector_type, devents, parameter_dependencies, metadata, gui_metadata,
140-
complete, index_cache, isscheduled)
143+
connector_type, cevents, devents, parameter_dependencies, metadata,
144+
gui_metadata, complete, index_cache, isscheduled)
141145
end
142146
end
143147
function JumpSystem(tag, ap, iv, states, ps, var_to_name, args...; kwargs...)
@@ -194,26 +198,28 @@ function JumpSystem(eqs, iv, unknowns, ps;
194198
# this and the treatment of continuous events are the only part
195199
# unique to JumpSystems
196200
eqs = scalarize.(eqs)
197-
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
201+
ap = ArrayPartition(
202+
MassActionJump[], ConstantRateJump[], VariableRateJump[], Equation[])
198203
for eq in eqs
199204
if eq isa MassActionJump
200205
push!(ap.x[1], eq)
201206
elseif eq isa ConstantRateJump
202207
push!(ap.x[2], eq)
203208
elseif eq isa VariableRateJump
204209
push!(ap.x[3], eq)
210+
elseif eq isa Equation
211+
push!(ap.x[4], eq)
205212
else
206-
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
213+
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, VariableRateJumps, or Equations.")
207214
end
208215
end
209216

210-
(continuous_events === nothing) ||
211-
error("JumpSystems currently only support discrete events.")
217+
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
212218
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
213219

214220
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
215221
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,
216-
defaults, connector_type, disc_callbacks, parameter_dependencies,
222+
defaults, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies,
217223
metadata, gui_metadata, checks = checks)
218224
end
219225

@@ -245,6 +251,7 @@ end
245251
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
246252
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
247253
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
254+
has_equations(js::JumpSystem) = !isempty(equations(js).x[4])
248255

249256
function generate_rate_function(js::JumpSystem, rate)
250257
consts = collect_constants(rate)
@@ -281,7 +288,7 @@ function assemble_vrj(
281288
outputidxs = [unknowntoid[var] for var in outputvars]
282289
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
283290
eval_expression, eval_module)
284-
VariableRateJump(rate, affect)
291+
VariableRateJump(rate, affect; save_positions = vrj.save_positions)
285292
end
286293

287294
function assemble_vrj_expr(js, vrj, unknowntoid)
@@ -390,6 +397,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
390397
if !iscomplete(sys)
391398
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
392399
end
400+
401+
if has_equations(sys) || (!isempty(continuous_events(sys)))
402+
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
403+
end
404+
393405
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
394406
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
395407
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
@@ -478,14 +490,24 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
478490
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
479491
end
480492

481-
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
482-
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
483-
484-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
485-
486-
f = (du, u, p, t) -> (du .= 0; nothing)
487-
df = ODEFunction(f; sys, observed = observedfun)
488-
ODEProblem(df, u0, tspan, p; kwargs...)
493+
# forward everything to be an ODESystem but the jumps and discrete events
494+
if has_equations(sys)
495+
osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys);
496+
observed = observed(sys), name = nameof(sys), description = description(sys),
497+
systems = get_systems(sys), defaults = defaults(sys),
498+
parameter_dependencies = parameter_dependencies(sys),
499+
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
500+
osys = complete(osys)
501+
return ODEProblem(osys, u0map, tspan, parammap; check_length = false, kwargs...)
502+
else
503+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
504+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
505+
check_length = false)
506+
f = (du, u, p, t) -> (du .= 0; nothing)
507+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
508+
df = ODEFunction(f; sys, observed = observedfun)
509+
return ODEProblem(df, u0, tspan, p; kwargs...)
510+
end
489511
end
490512

491513
"""
@@ -521,8 +543,11 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
521543
for j in eqs.x[2]]
522544
vrjs = VariableRateJump[assemble_vrj(js, j, unknowntoid; eval_expression, eval_module)
523545
for j in eqs.x[3]]
524-
((prob isa DiscreteProblem) && !isempty(vrjs)) &&
525-
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
546+
if prob isa DiscreteProblem
547+
if (!isempty(vrjs) || has_equations(js) || !isempty(continuous_events(js)))
548+
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps, coupled differential equations, or continuous events.")
549+
end
550+
end
526551
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
527552

528553
# dep graphs are only for constant rate jumps

test/jumpsystem.jl

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra
2-
using Random, StableRNGs
2+
using Random, StableRNGs, NonlinearSolve
33
using OrdinaryDiffEq
44
using ModelingToolkit: t_nounits as t, D_nounits as D
55
MT = ModelingToolkit
@@ -422,3 +422,114 @@ let
422422
@test issetequal(us, [x5])
423423
@test issetequal(ps, [p5])
424424
end
425+
426+
# PDMP test
427+
let
428+
seed = 1111
429+
Random.seed!(rng, seed)
430+
@variables X(t) Y(t)
431+
@parameters k1 k2
432+
vrj1 = VariableRateJump(k1 * X, [X ~ X - 1]; save_positions = (false, false))
433+
vrj2 = VariableRateJump(k1, [Y ~ Y + 1]; save_positions = (false, false))
434+
eqs = [D(X) ~ k2, D(Y) ~ -k2 / 10 * Y]
435+
@named jsys = JumpSystem([vrj1, vrj2, eqs[1], eqs[2]], t, [X, Y], [k1, k2])
436+
jsys = complete(jsys)
437+
X0 = 0.0
438+
Y0 = 3.0
439+
u0 = [X => X0, Y => Y0]
440+
k1val = 1.0
441+
k2val = 20.0
442+
p = [k1 => k1val, k2 => k2val]
443+
tspan = (0.0, 10.0)
444+
oprob = ODEProblem(jsys, u0, tspan, p)
445+
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
446+
447+
times = range(0.0, tspan[2], length = 100)
448+
Nsims = 4000
449+
Xv = zeros(length(times))
450+
Yv = zeros(length(times))
451+
for n in 1:Nsims
452+
sol = solve(jprob, Tsit5(); saveat = times, seed)
453+
Xv .+= sol[1, :]
454+
Yv .+= sol[2, :]
455+
seed += 1
456+
end
457+
Xv ./= Nsims
458+
Yv ./= Nsims
459+
460+
Xact(t) = X0 * exp(-k1val * t) + (k2val / k1val) * (1 - exp(-k1val * t))
461+
function Yact(t)
462+
Y0 * exp(-k2val / 10 * t) + (k1val / (k2val / 10)) * (1 - exp(-k2val / 10 * t))
463+
end
464+
@test all(abs.(Xv .- Xact.(times)) .<= 0.05 .* Xv)
465+
@test all(abs.(Yv .- Yact.(times)) .<= 0.1 .* Yv)
466+
end
467+
468+
# that mixes ODEs and jump types, and then contin events
469+
let
470+
seed = 1111
471+
Random.seed!(rng, seed)
472+
@variables X(t) Y(t)
473+
@parameters α β
474+
vrj = VariableRateJump* X, [X ~ X - 1]; save_positions = (false, false))
475+
crj = ConstantRateJump* Y, [Y ~ Y - 1])
476+
maj = MassActionJump(α, [0 => 1], [Y => 1])
477+
eqs = [D(X) ~ α * (1 + Y)]
478+
@named jsys = JumpSystem([maj, crj, vrj, eqs[1]], t, [X, Y], [α, β])
479+
jsys = complete(jsys)
480+
p == 6.0, β = 2.0, X₀ = 2.0, Y₀ = 1.0)
481+
u0map = [X => p.X₀, Y => p.Y₀]
482+
pmap ==> p.α, β => p.β]
483+
tspan = (0.0, 20.0)
484+
oprob = ODEProblem(jsys, u0map, tspan, pmap)
485+
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
486+
times = range(0.0, tspan[2], length = 100)
487+
Nsims = 4000
488+
Xv = zeros(length(times))
489+
Yv = zeros(length(times))
490+
for n in 1:Nsims
491+
sol = solve(jprob, Tsit5(); saveat = times, seed)
492+
Xv .+= sol[1, :]
493+
Yv .+= sol[2, :]
494+
seed += 1
495+
end
496+
Xv ./= Nsims
497+
Yv ./= Nsims
498+
499+
function Yf(t, p)
500+
local α, β, X₀, Y₀ = p
501+
return/ β) + (Y₀ - α / β) * exp(-β * t)
502+
end
503+
function Xf(t, p)
504+
local α, β, X₀, Y₀ = p
505+
return/ β) +^2 / β^2) + α * (Y₀ - α / β) * t * exp(-β * t) +
506+
(X₀ - α / β - α^2 / β^2) * exp(-β * t)
507+
end
508+
Xact = [Xf(t, p) for t in times]
509+
Yact = [Yf(t, p) for t in times]
510+
@test all(abs.(Xv .- Xact) .<= 0.05 .* Xv)
511+
@test all(abs.(Yv .- Yact) .<= 0.05 .* Yv)
512+
513+
function affect!(integ, u, p, ctx)
514+
savevalues!(integ, true)
515+
terminate!(integ)
516+
nothing
517+
end
518+
cevents = [t ~ 0.2] => (affect!, [], [], [], nothing)
519+
@named jsys = JumpSystem([maj, crj, vrj, eqs[1]], t, [X, Y], [α, β];
520+
continuous_events = cevents)
521+
jsys = complete(jsys)
522+
tspan = (0.0, 200.0)
523+
oprob = ODEProblem(jsys, u0map, tspan, pmap)
524+
jprob = JumpProblem(jsys, oprob; rng, save_positions = (false, false))
525+
Xsamp = 0.0
526+
Nsims = 4000
527+
for n in 1:Nsims
528+
sol = solve(jprob, Tsit5(); saveat = tspan[2], seed)
529+
@test sol.retcode == ReturnCode.Terminated
530+
Xsamp += sol[1, end]
531+
seed += 1
532+
end
533+
Xsamp /= Nsims
534+
@test abs(Xsamp - Xf(0.2, p) < 0.05 * Xf(0.2, p))
535+
end

0 commit comments

Comments
 (0)