Skip to content

Commit a0cd75b

Browse files
Merge pull request #317 from isaacsas/jumpsystems
JumpSystems for constant and variable rate jumps
2 parents 94a5f4e + 3968efd commit a0cd75b

File tree

5 files changed

+239
-30
lines changed

5 files changed

+239
-30
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "3.1.1"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
9+
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
910
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1011
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1112
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/ModelingToolkit.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays
55
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
8+
using DiffEqJump
89

910
using Base.Threads
1011
import MacroTools: splitdef, combinedef, postwalk, striplines
@@ -86,6 +87,8 @@ include("systems/diffeqs/first_order_transform.jl")
8687
include("systems/diffeqs/modelingtoolkitize.jl")
8788
include("systems/diffeqs/validation.jl")
8889

90+
include("systems/jumps/jumpsystem.jl")
91+
8992
include("systems/nonlinear/nonlinearsystem.jl")
9093

9194
include("systems/optimization/optimizationsystem.jl")
@@ -99,7 +102,8 @@ include("build_function.jl")
99102

100103
export ODESystem, ODEFunction
101104
export SDESystem, SDEFunction
102-
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem
105+
export JumpSystem
106+
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, JumpProblem
103107
export NonlinearSystem, OptimizationSystem
104108
export ode_order_lowering
105109
export PDESystem

src/build_function.jl

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,50 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5656
_build_function(target,args...;kwargs...)
5757
end
5858

59+
function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
60+
if iip
61+
wrappedex = :(
62+
($X,$(fargs.args...)) -> begin
63+
$ex
64+
nothing
65+
end
66+
)
67+
else
68+
wrappedex = :(
69+
($(fargs.args...),) -> begin
70+
$ex
71+
end
72+
)
73+
end
74+
wrappedex
75+
end
76+
77+
function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
78+
integrator = gensym(:MTKIntegrator)
79+
if iip
80+
wrappedex = :(
81+
$integrator -> begin
82+
($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
83+
$ex
84+
nothing
85+
end
86+
)
87+
else
88+
wrappedex = :(
89+
$integrator -> begin
90+
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
91+
$ex
92+
end
93+
)
94+
end
95+
wrappedex
96+
end
97+
5998
# Scalar output
6099
function _build_function(target::JuliaTarget, op::Operation, args...;
61100
conv = simplified_expr, expression = Val{true},
62101
checkbounds = false, constructor=nothing,
63-
linenumbers = true)
102+
linenumbers = true, headerfun=addheader)
64103

65104
argnames = [gensym(:MTKArg) for i in 1:length(args)]
66105
arg_pairs = map(vars_to_pairs,zip(argnames,args))
@@ -74,13 +113,8 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
74113
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
75114

76115
fargs = Expr(:tuple,argnames...)
77-
78-
oop_ex = :(
79-
($(fargs.args...),) -> begin
80-
$bounds_block
81-
end
82-
)
83-
116+
oop_ex = headerfun(bounds_block, fargs, false)
117+
84118
if !linenumbers
85119
oop_ex = striplines(oop_ex)
86120
end
@@ -95,8 +129,8 @@ end
95129
function _build_function(target::JuliaTarget, rhss, args...;
96130
conv = simplified_expr, expression = Val{true},
97131
checkbounds = false, constructor=nothing,
98-
linenumbers = false, multithread=false)
99-
132+
linenumbers = false, multithread=false,
133+
headerfun=addheader, outputidxs=nothing)
100134
argnames = [gensym(:MTKArg) for i in 1:length(args)]
101135
arg_pairs = map(vars_to_pairs,zip(argnames,args))
102136
ls = reduce(vcat,first.(arg_pairs))
@@ -106,6 +140,8 @@ function _build_function(target::JuliaTarget, rhss, args...;
106140
fname = gensym(:ModelingToolkitFunction)
107141
fargs = Expr(:tuple,argnames...)
108142

143+
144+
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
109145
X = gensym(:MTIIPVar)
110146
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
111147
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(rhss)],init=Expr[])
@@ -118,7 +154,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
118154
elseif rhss isa SparseMatrixCSC
119155
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
120156
else
121-
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
157+
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
122158
end
123159

124160
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
@@ -165,26 +201,20 @@ function _build_function(target::JuliaTarget, rhss, args...;
165201
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
166202
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
167203

168-
oop_ex = :(
169-
($(fargs.args...),) -> begin
170-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
171-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
172-
return $arr_bounds_block
173-
else
174-
X = $bounds_block
175-
construct = $_constructor
176-
return construct(X)
177-
end
178-
end
179-
)
180-
181-
iip_ex = :(
182-
($X,$(fargs.args...)) -> begin
183-
$ip_bounds_block
184-
nothing
185-
end
204+
oop_body_block = :(
205+
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
206+
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
207+
return $arr_bounds_block
208+
else
209+
X = $bounds_block
210+
construct = $_constructor
211+
return construct(X)
212+
end
186213
)
187214

215+
oop_ex = headerfun(oop_body_block, fargs, false)
216+
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
217+
188218
if !linenumbers
189219
oop_ex = striplines(oop_ex)
190220
iip_ex = striplines(iip_ex)

src/systems/jumps/jumpsystem.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}
2+
3+
struct JumpSystem <: AbstractSystem
4+
eqs::Vector{JumpType}
5+
iv::Variable
6+
states::Vector{Variable}
7+
ps::Vector{Variable}
8+
name::Symbol
9+
systems::Vector{JumpSystem}
10+
end
11+
12+
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
13+
name = gensym(:JumpSystem))
14+
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
15+
end
16+
17+
18+
19+
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
20+
independent_variable(js),
21+
expression=Val{false})
22+
23+
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
24+
parameters(js),
25+
independent_variable(js),
26+
expression=Val{false},
27+
headerfun=add_integrator_header,
28+
outputidxs=outputidxs)[2]
29+
function assemble_vrj(js, vrj, statetoid)
30+
rate = generate_rate_function(js, vrj.rate)
31+
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
32+
outputidxs = ((statetoid[var] for var in outputvars)...,)
33+
affect = generate_affect_function(js, vrj.affect!, outputidxs)
34+
VariableRateJump(rate, affect)
35+
end
36+
37+
function assemble_crj(js, crj, statetoid)
38+
rate = generate_rate_function(js, crj.rate)
39+
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!)
40+
outputidxs = ((statetoid[var] for var in outputvars)...,)
41+
affect = generate_affect_function(js, crj.affect!, outputidxs)
42+
ConstantRateJump(rate, affect)
43+
end
44+
45+
"""
46+
```julia
47+
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
48+
```
49+
50+
Generates a JumpProblem from a JumpSystem.
51+
"""
52+
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
53+
vrjs = Vector{VariableRateJump}()
54+
crjs = Vector{ConstantRateJump}()
55+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
56+
for j in equations(js)
57+
if j isa ConstantRateJump
58+
push!(crjs, assemble_crj(js, j, statetoid))
59+
elseif j isa VariableRateJump
60+
push!(vrjs, assemble_vrj(js, j, statetoid))
61+
else
62+
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.")
63+
end
64+
end
65+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
66+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing)
67+
JumpProblem(prob, aggregator, jset)
68+
end

test/jumpsystem.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using ModelingToolkit, DiffEqBase, DiffEqJump, Test, LinearAlgebra
2+
MT = ModelingToolkit
3+
4+
# basic MT SIR model with tweaks
5+
@parameters β γ t
6+
@variables S I R
7+
rate₁ = β*S*I
8+
affect₁ = [S ~ S - 1, I ~ I + 1]
9+
rate₂ = γ*I+t
10+
affect₂ = [I ~ I - 1, R ~ R + 1]
11+
j₁ = ConstantRateJump(rate₁,affect₁)
12+
j₂ = VariableRateJump(rate₂,affect₂)
13+
js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ])
14+
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
15+
mtjump1 = MT.assemble_crj(js, j₁, statetoid)
16+
mtjump2 = MT.assemble_vrj(js, j₂, statetoid)
17+
18+
# doc version
19+
rate1(u,p,t) = (0.1/1000.0)*u[1]*u[2]
20+
function affect1!(integrator)
21+
integrator.u[1] -= 1
22+
integrator.u[2] += 1
23+
end
24+
jump1 = ConstantRateJump(rate1,affect1!)
25+
rate2(u,p,t) = 0.01u[2]+t
26+
function affect2!(integrator)
27+
integrator.u[2] -= 1
28+
integrator.u[3] += 1
29+
end
30+
jump2 = VariableRateJump(rate2,affect2!)
31+
32+
# test crjs
33+
u = [100, 9, 5]
34+
p = (0.1/1000,0.01)
35+
tf = 1.0
36+
mutable struct TestInt{U,V,T}
37+
u::U
38+
p::V
39+
t::T
40+
end
41+
mtintegrator = TestInt(u,p,tf)
42+
integrator = TestInt(u,p,tf)
43+
@test abs(mtjump1.rate(u,p,tf) - jump1.rate(u,p,tf)) < 10*eps()
44+
@test abs(mtjump2.rate(u,p,tf) - jump2.rate(u,p,tf)) < 10*eps()
45+
mtjump1.affect!(mtintegrator)
46+
jump1.affect!(integrator)
47+
@test all(integrator.u .== mtintegrator.u)
48+
mtintegrator.u .= u; integrator.u .= u
49+
mtjump2.affect!(mtintegrator)
50+
jump2.affect!(integrator)
51+
@test all(integrator.u .== mtintegrator.u)
52+
53+
# test MT can make and solve a jump problem
54+
rate₃ = γ*I
55+
affect₃ = [I ~ I - 1, R ~ R + 1]
56+
j₃ = ConstantRateJump(rate₃,affect₃)
57+
js2 = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ])
58+
u₀ = [999,1,0]; p = (0.1/1000,0.01); tspan = (0.,250.)
59+
dprob = DiscreteProblem(u₀,tspan,p)
60+
jprob = JumpProblem(js2, dprob, Direct(), save_positions=(false,false))
61+
Nsims = 10000
62+
function getmean(jprob,Nsims)
63+
m = 0.0
64+
for i = 1:Nsims
65+
sol = solve(jprob, SSAStepper())
66+
m += sol[end,end]
67+
end
68+
m/Nsims
69+
end
70+
m = getmean(jprob,Nsims)
71+
72+
#test the MT JumpProblem rates/affects are correct
73+
rate2(u,p,t) = 0.01u[2]
74+
jump2 = ConstantRateJump(rate2,affect2!)
75+
mtjumps = jprob.discrete_jump_aggregation
76+
@test abs(mtjumps.rates[1](u,p,tf) - jump1.rate(u,p,tf)) < 10*eps()
77+
@test abs(mtjumps.rates[2](u,p,tf) - jump2.rate(u,p,tf)) < 10*eps()
78+
mtjumps.affects![1](mtintegrator)
79+
jump1.affect!(integrator)
80+
@test all(integrator.u .== mtintegrator.u)
81+
mtintegrator.u .= u; integrator.u .= u
82+
mtjumps.affects![2](mtintegrator)
83+
jump2.affect!(integrator)
84+
@test all(integrator.u .== mtintegrator.u)
85+
86+
# direct vers
87+
p = (0.1/1000,0.01)
88+
prob = DiscreteProblem([999,1,0],(0.0,250.0),p)
89+
r1(u,p,t) = (0.1/1000.0)*u[1]*u[2]
90+
function a1!(integrator)
91+
integrator.u[1] -= 1
92+
integrator.u[2] += 1
93+
end
94+
j1 = ConstantRateJump(r1,a1!)
95+
r2(u,p,t) = 0.01u[2]
96+
function a2!(integrator)
97+
integrator.u[2] -= 1
98+
integrator.u[3] += 1
99+
end
100+
j2 = ConstantRateJump(r2,a2!)
101+
jset = JumpSet((),(j1,j2),nothing,nothing)
102+
jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false))
103+
m2 = getmean(jprob,Nsims)
104+
105+
# test JumpSystem solution agrees with direct version
106+
@test abs(m-m2) ./ m < .01

0 commit comments

Comments
 (0)