Skip to content

Commit f65a39d

Browse files
committed
can make JumpProblem, but wrong answers on SIR
1 parent 976c22e commit f65a39d

File tree

3 files changed

+43
-28
lines changed

3 files changed

+43
-28
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays
55
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
8-
using DiffEqJump: VariableRateJump, ConstantRateJump, MassActionJump
8+
using DiffEqJump
99

1010
using Base.Threads
1111
import MacroTools: splitdef, combinedef, postwalk, striplines
@@ -104,7 +104,7 @@ include("build_function.jl")
104104
export ODESystem, ODEFunction
105105
export SDESystem, SDEFunction
106106
export JumpSystem
107-
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem
107+
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, JumpProblem
108108
export NonlinearSystem, OptimizationSystem
109109
export ode_order_lowering
110110
export PDESystem

src/systems/jumps/jumpsystem.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ generate_affect_function(js, affect) = build_function(affect, states(js),
2525
independent_variable(js),
2626
expression=Val{false},
2727
headerfun=add_integrator_header)[2]
28-
function assemble_vrj(js, vrj)
28+
function assemble_vrj(js, vrj)
2929
rate = generate_rate_function(js, vrj.rate)
3030
affect = generate_affect_function(js, vrj.affect!)
3131
VariableRateJump(rate, affect)
@@ -37,23 +37,26 @@ function assemble_crj(js, crj)
3737
ConstantRateJump(rate, affect)
3838
end
3939

40-
function assemble_maj(maj, states_to_idxs, ps_to_idxs; scale_rate=false)
41-
42-
# mass action scaled_rates need to be a Number, but
43-
# operations are numbers, so can't check the type directly
44-
@assert !isa(maj.scaled_rates, Union{Operation,Variable})
45-
46-
rstype = fieldtypes(eltype(maj.reactant_stoch))[2]
47-
rs = Vector{Pair{valtype(states_to_idxs),rstype}}()
48-
for (spec,stoich) in maj.reactant_stoch
49-
push!(rs, states_to_idxs[spec] => stoich)
40+
"""
41+
```julia
42+
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
43+
```
44+
45+
Generates a JumpProblem from a JumpSystem.
46+
"""
47+
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
48+
vrjs = Vector{VariableRateJump}()
49+
crjs = Vector{ConstantRateJump}()
50+
for j in equations(js)
51+
if j isa ConstantRateJump
52+
push!(crjs, assemble_crj(js, j))
53+
elseif j isa VariableRateJump
54+
push!(vrjs, assemble_vrj(js, j))
55+
else
56+
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.")
57+
end
5058
end
51-
52-
nstype = fieldtypes(eltype(maj.net_stoch))[2]
53-
ns = Vector{Pair{valtype(states_to_idxs),nstype}}()
54-
for (spec,stoich) in maj.net_stoch
55-
push!(ns, states_to_idxs[spec] => stoich)
56-
end
57-
58-
MassActionJump(maj.scaled_rates, rs, ns, scale_rates=scale_rate)
59+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
60+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing)
61+
JumpProblem(prob, aggregator, jset)
5962
end

test/jumpsystem.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, DiffEqJump, Test, LinearAlgebra
1+
using ModelingToolkit, DiffEqBase, DiffEqJump, Test, LinearAlgebra
22
MT = ModelingToolkit
33

44
# basic SIR model with tweaks
@@ -12,7 +12,7 @@ j₁ = ConstantRateJump(rate₁,affect₁)
1212
j₂ = VariableRateJump(rate₂,affect₂)
1313
js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ])
1414
mtjump1 = MT.assemble_crj(js, j₁)
15-
mtjump2 = MT.assemble_crj(js, j₂)
15+
mtjump2 = MT.assemble_vrj(js, j₂)
1616

1717
# doc version
1818
rate1(u,p,t) = (0.1/1000.0)*u[1]*u[2]
@@ -31,16 +31,16 @@ jump2 = VariableRateJump(rate2,affect2!)
3131
# test crjs
3232
u = [100, 9, 5]
3333
p = (0.1/1000,0.01)
34-
t = 1.0
34+
tf = 1.0
3535
mutable struct TestInt
3636
u
3737
p
3838
t
3939
end
40-
mtintegrator = TestInt(u,p,t)
41-
integrator = TestInt(u,p,t)
42-
@test abs(mtjump1.rate(u,p,t) - jump1.rate(u,p,t)) < 10*eps()
43-
@test abs(mtjump2.rate(u,p,t) - jump2.rate(u,p,t)) < 10*eps()
40+
mtintegrator = TestInt(u,p,tf)
41+
integrator = TestInt(u,p,tf)
42+
@test abs(mtjump1.rate(u,p,tf) - jump1.rate(u,p,tf)) < 10*eps()
43+
@test abs(mtjump2.rate(u,p,tf) - jump2.rate(u,p,tf)) < 10*eps()
4444
mtjump1.affect!(mtintegrator)
4545
jump1.affect!(integrator)
4646
@test norm(integrator.u - mtintegrator.u) < 10*eps()
@@ -50,3 +50,15 @@ jump2.affect!(integrator)
5050
@test norm(integrator.u - mtintegrator.u) < 10*eps()
5151

5252

53+
# test can make and solve a jump problem
54+
rate₂ = γ*I
55+
affect₂ = [I ~ I - 1, R ~ R + 1]
56+
j₃ = ConstantRateJump(rate₂,affect₂)
57+
js2 = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ])
58+
u₀ = [999,1,0]; p = (0.1/1000,0.01); tspan = (0.,250.)
59+
dprob = DiscreteProblem(u₀,tspan,p)
60+
jprob = JumpProblem(js2, dprob, Direct())
61+
sol = solve(jprob, SSAStepper())
62+
63+
using Plots
64+
plot(sol)

0 commit comments

Comments
 (0)